// interactive.go

package cmd

import (
	"cmp"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	"github.com/spf13/cobra"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/readline"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
)

// MultilineState records whether the REPL is in the middle of collecting a
// """-delimited multi-line input, and if so, what that input is for.
type MultilineState int

const (
	// MultilineNone means no multi-line input is in progress.
	MultilineNone MultilineState = iota
	// MultilinePrompt means a multi-line user prompt is being collected.
	MultilinePrompt
	// MultilineSystem means a multi-line system message (/set system """...)
	// is being collected.
	MultilineSystem
)

32
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
33
34
	usage := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
35
36
37
38
		fmt.Fprintln(os.Stderr, "  /set            Set session variables")
		fmt.Fprintln(os.Stderr, "  /show           Show model information")
		fmt.Fprintln(os.Stderr, "  /load <model>   Load a session or model")
		fmt.Fprintln(os.Stderr, "  /save <model>   Save your current session")
Bryce Reitano's avatar
Bryce Reitano committed
39
		fmt.Fprintln(os.Stderr, "  /clear          Clear session context")
40
41
42
		fmt.Fprintln(os.Stderr, "  /bye            Exit")
		fmt.Fprintln(os.Stderr, "  /?, /help       Help for a command")
		fmt.Fprintln(os.Stderr, "  /? shortcuts    Help for keyboard shortcuts")
43
44
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
45
46

		if opts.MultiModal {
47
			fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
48
49
		}

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
		fmt.Fprintln(os.Stderr, "")
	}

	usageSet := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
		fmt.Fprintln(os.Stderr, "  /set parameter ...     Set a parameter")
		fmt.Fprintln(os.Stderr, "  /set system <string>   Set system message")
		fmt.Fprintln(os.Stderr, "  /set history           Enable history")
		fmt.Fprintln(os.Stderr, "  /set nohistory         Disable history")
		fmt.Fprintln(os.Stderr, "  /set wordwrap          Enable wordwrap")
		fmt.Fprintln(os.Stderr, "  /set nowordwrap        Disable wordwrap")
		fmt.Fprintln(os.Stderr, "  /set format json       Enable JSON mode")
		fmt.Fprintln(os.Stderr, "  /set noformat          Disable formatting")
		fmt.Fprintln(os.Stderr, "  /set verbose           Show LLM stats")
		fmt.Fprintln(os.Stderr, "  /set quiet             Disable LLM stats")
65
66
		fmt.Fprintln(os.Stderr, "  /set think             Enable thinking")
		fmt.Fprintln(os.Stderr, "  /set nothink           Disable thinking")
67
68
69
70
71
72
73
74
75
76
77
		fmt.Fprintln(os.Stderr, "")
	}

	usageShortcuts := func() {
		fmt.Fprintln(os.Stderr, "Available keyboard shortcuts:")
		fmt.Fprintln(os.Stderr, "  Ctrl + a            Move to the beginning of the line (Home)")
		fmt.Fprintln(os.Stderr, "  Ctrl + e            Move to the end of the line (End)")
		fmt.Fprintln(os.Stderr, "   Alt + b            Move back (left) one word")
		fmt.Fprintln(os.Stderr, "   Alt + f            Move forward (right) one word")
		fmt.Fprintln(os.Stderr, "  Ctrl + k            Delete the sentence after the cursor")
		fmt.Fprintln(os.Stderr, "  Ctrl + u            Delete the sentence before the cursor")
Josh Yan's avatar
Josh Yan committed
78
		fmt.Fprintln(os.Stderr, "  Ctrl + w            Delete the word before the cursor")
79
80
81
82
83
84
85
86
87
		fmt.Fprintln(os.Stderr, "")
		fmt.Fprintln(os.Stderr, "  Ctrl + l            Clear the screen")
		fmt.Fprintln(os.Stderr, "  Ctrl + c            Stop the model from responding")
		fmt.Fprintln(os.Stderr, "  Ctrl + d            Exit ollama (/bye)")
		fmt.Fprintln(os.Stderr, "")
	}

	usageShow := func() {
		fmt.Fprintln(os.Stderr, "Available Commands:")
88
		fmt.Fprintln(os.Stderr, "  /show info         Show details for this model")
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
		fmt.Fprintln(os.Stderr, "  /show license      Show model license")
		fmt.Fprintln(os.Stderr, "  /show modelfile    Show Modelfile for this model")
		fmt.Fprintln(os.Stderr, "  /show parameters   Show parameters for this model")
		fmt.Fprintln(os.Stderr, "  /show system       Show system message")
		fmt.Fprintln(os.Stderr, "  /show template     Show prompt template")
		fmt.Fprintln(os.Stderr, "")
	}

	// only list out the most common parameters
	usageParameters := func() {
		fmt.Fprintln(os.Stderr, "Available Parameters:")
		fmt.Fprintln(os.Stderr, "  /set parameter seed <int>             Random number seed")
		fmt.Fprintln(os.Stderr, "  /set parameter num_predict <int>      Max number of tokens to predict")
		fmt.Fprintln(os.Stderr, "  /set parameter top_k <int>            Pick from top k num of tokens")
		fmt.Fprintln(os.Stderr, "  /set parameter top_p <float>          Pick token based on sum of probabilities")
104
		fmt.Fprintln(os.Stderr, "  /set parameter min_p <float>          Pick token based on top token probability * min_p")
105
106
107
108
109
		fmt.Fprintln(os.Stderr, "  /set parameter num_ctx <int>          Set the context size")
		fmt.Fprintln(os.Stderr, "  /set parameter temperature <float>    Set creativity level")
		fmt.Fprintln(os.Stderr, "  /set parameter repeat_penalty <float> How strongly to penalize repetitions")
		fmt.Fprintln(os.Stderr, "  /set parameter repeat_last_n <int>    Set how far back to look for repetitions")
		fmt.Fprintln(os.Stderr, "  /set parameter num_gpu <int>          The number of layers to send to the GPU")
110
		fmt.Fprintln(os.Stderr, "  /set parameter stop <string> <string> ...   Set the stop parameters")
111
112
113
114
115
116
117
118
119
120
121
122
123
		fmt.Fprintln(os.Stderr, "")
	}

	scanner, err := readline.New(readline.Prompt{
		Prompt:         ">>> ",
		AltPrompt:      "... ",
		Placeholder:    "Send a message (/? for help)",
		AltPlaceholder: `Use """ to end multi-line input`,
	})
	if err != nil {
		return err
	}

Michael Yang's avatar
bool  
Michael Yang committed
124
	if envconfig.NoHistory() {
125
126
127
		scanner.HistoryDisable()
	}

128
129
130
	fmt.Print(readline.StartBracketedPaste)
	defer fmt.Printf(readline.EndBracketedPaste)

131
	var sb strings.Builder
132
	var multiline MultilineState
133
	var thinkExplicitlySet bool = opts.Think != nil
134
135
136
137
138
139
140
141
142
143
144
145
146

	for {
		line, err := scanner.Readline()
		switch {
		case errors.Is(err, io.EOF):
			fmt.Println()
			return nil
		case errors.Is(err, readline.ErrInterrupt):
			if line == "" {
				fmt.Println("\nUse Ctrl + d or /bye to exit.")
			}

			scanner.Prompt.UseAlt = false
147
			sb.Reset()
148
149
150
151
152
153
154

			continue
		case err != nil:
			return err
		}

		switch {
155
156
157
158
159
160
		case multiline != MultilineNone:
			// check if there's a multiline terminating string
			before, ok := strings.CutSuffix(line, `"""`)
			sb.WriteString(before)
			if !ok {
				fmt.Fprintln(&sb)
161
162
163
164
165
				continue
			}

			switch multiline {
			case MultilineSystem:
166
				opts.System = sb.String()
167
				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
168
				fmt.Println("Set system message.")
169
				sb.Reset()
170
			}
171

172
			multiline = MultilineNone
173
174
175
176
177
178
179
180
181
182
183
			scanner.Prompt.UseAlt = false
		case strings.HasPrefix(line, `"""`):
			line := strings.TrimPrefix(line, `"""`)
			line, ok := strings.CutSuffix(line, `"""`)
			sb.WriteString(line)
			if !ok {
				// no multiline terminating string; need more input
				fmt.Fprintln(&sb)
				multiline = MultilinePrompt
				scanner.Prompt.UseAlt = true
			}
184
		case scanner.Pasting:
185
			fmt.Fprintln(&sb, line)
186
187
188
189
190
191
			continue
		case strings.HasPrefix(line, "/list"):
			args := strings.Fields(line)
			if err := ListHandler(cmd, args[1:]); err != nil {
				return err
			}
192
193
194
195
196
197
		case strings.HasPrefix(line, "/load"):
			args := strings.Fields(line)
			if len(args) != 2 {
				fmt.Println("Usage:\n  /load <modelname>")
				continue
			}
198
199
			origOpts := opts.Copy()

200
201
202
			opts.Model = args[1]
			opts.Messages = []api.Message{}
			fmt.Printf("Loading model '%s'\n", opts.Model)
203
204
			opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
			if err != nil {
205
206
207
208
209
				if strings.Contains(err.Error(), "not found") {
					fmt.Printf("Couldn't find model '%s'\n", opts.Model)
					opts = origOpts.Copy()
					continue
				}
210
211
				return err
			}
Patrick Devine's avatar
Patrick Devine committed
212
			if err := loadOrUnloadModel(cmd, &opts); err != nil {
213
				if strings.Contains(err.Error(), "not found") {
214
215
					fmt.Printf("Couldn't find model '%s'\n", opts.Model)
					opts = origOpts.Copy()
216
217
					continue
				}
218
219
220
221
				if strings.Contains(err.Error(), "does not support thinking") {
					fmt.Printf("error: %v\n", err)
					continue
				}
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
				return err
			}
			continue
		case strings.HasPrefix(line, "/save"):
			args := strings.Fields(line)
			if len(args) != 2 {
				fmt.Println("Usage:\n  /save <modelname>")
				continue
			}

			client, err := api.ClientFromEnvironment()
			if err != nil {
				fmt.Println("error: couldn't connect to ollama server")
				return err
			}

238
			req := NewCreateRequest(args[1], opts)
239
240
241
			fn := func(resp api.ProgressResponse) error { return nil }
			err = client.Create(cmd.Context(), req, fn)
			if err != nil {
242
243
244
245
				if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
					fmt.Printf("error: The model name '%s' is invalid\n", args[1])
					continue
				}
246
247
248
249
				return err
			}
			fmt.Printf("Created new model '%s'\n", args[1])
			continue
Bryce Reitano's avatar
Bryce Reitano committed
250
251
		case strings.HasPrefix(line, "/clear"):
			opts.Messages = []api.Message{}
Patrick Devine's avatar
Patrick Devine committed
252
253
254
255
			if opts.System != "" {
				newMessage := api.Message{Role: "system", Content: opts.System}
				opts.Messages = append(opts.Messages, newMessage)
			}
Bryce Reitano's avatar
Bryce Reitano committed
256
257
			fmt.Println("Cleared session context")
			continue
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
		case strings.HasPrefix(line, "/set"):
			args := strings.Fields(line)
			if len(args) > 1 {
				switch args[1] {
				case "history":
					scanner.HistoryEnable()
				case "nohistory":
					scanner.HistoryDisable()
				case "wordwrap":
					opts.WordWrap = true
					fmt.Println("Set 'wordwrap' mode.")
				case "nowordwrap":
					opts.WordWrap = false
					fmt.Println("Set 'nowordwrap' mode.")
				case "verbose":
273
274
275
					if err := cmd.Flags().Set("verbose", "true"); err != nil {
						return err
					}
276
277
					fmt.Println("Set 'verbose' mode.")
				case "quiet":
278
279
280
					if err := cmd.Flags().Set("verbose", "false"); err != nil {
						return err
					}
281
					fmt.Println("Set 'quiet' mode.")
282
				case "think":
Michael Yang's avatar
Michael Yang committed
283
284
285
286
287
288
289
290
291
292
293
294
					thinkValue := api.ThinkValue{Value: true}
					var maybeLevel string
					if len(args) > 2 {
						maybeLevel = args[2]
					}
					if maybeLevel != "" {
						// TODO(drifkin): validate the level, could be model dependent
						// though... It will also be validated on the server once a call is
						// made.
						thinkValue.Value = maybeLevel
					}
					opts.Think = &thinkValue
295
296
297
298
					thinkExplicitlySet = true
					if client, err := api.ClientFromEnvironment(); err == nil {
						ensureThinkingSupport(cmd.Context(), client, opts.Model)
					}
Michael Yang's avatar
Michael Yang committed
299
300
301
302
303
					if maybeLevel != "" {
						fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
					} else {
						fmt.Println("Set 'think' mode.")
					}
304
				case "nothink":
Michael Yang's avatar
Michael Yang committed
305
					opts.Think = &api.ThinkValue{Value: false}
306
307
308
309
310
					thinkExplicitlySet = true
					if client, err := api.ClientFromEnvironment(); err == nil {
						ensureThinkingSupport(cmd.Context(), client, opts.Model)
					}
					fmt.Println("Set 'nothink' mode.")
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
				case "format":
					if len(args) < 3 || args[2] != "json" {
						fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
					} else {
						opts.Format = args[2]
						fmt.Printf("Set format to '%s' mode.\n", args[2])
					}
				case "noformat":
					opts.Format = ""
					fmt.Println("Disabled format.")
				case "parameter":
					if len(args) < 4 {
						usageParameters()
						continue
					}
Michael Yang's avatar
Michael Yang committed
326
					params := args[3:]
327
328
					fp, err := api.FormatParams(map[string][]string{args[2]: params})
					if err != nil {
329
						fmt.Printf("Couldn't set parameter: %q\n", err)
330
331
						continue
					}
332
					fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
333
					opts.Options[args[2]] = fp[args[2]]
Patrick Devine's avatar
Patrick Devine committed
334
				case "system":
335
336
337
338
					if len(args) < 3 {
						usageSet()
						continue
					}
339

Patrick Devine's avatar
Patrick Devine committed
340
					multiline = MultilineSystem
341

342
					line := strings.Join(args[2:], " ")
343
344
345
					line, ok := strings.CutPrefix(line, `"""`)
					if !ok {
						multiline = MultilineNone
346
					} else {
347
348
349
350
351
352
353
354
355
356
357
358
359
						// only cut suffix if the line is multiline
						line, ok = strings.CutSuffix(line, `"""`)
						if ok {
							multiline = MultilineNone
						}
					}

					sb.WriteString(line)
					if multiline != MultilineNone {
						scanner.Prompt.UseAlt = true
						continue
					}

Patrick Devine's avatar
Patrick Devine committed
360
361
362
363
364
365
366
367
					opts.System = sb.String() // for display in modelfile
					newMessage := api.Message{Role: "system", Content: sb.String()}
					// Check if the slice is not empty and the last message is from 'system'
					if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
						// Replace the last message
						opts.Messages[len(opts.Messages)-1] = newMessage
					} else {
						opts.Messages = append(opts.Messages, newMessage)
368
					}
Patrick Devine's avatar
Patrick Devine committed
369
					fmt.Println("Set system message.")
370
371
					sb.Reset()
					continue
372
373
374
375
376
377
378
379
380
381
382
383
384
385
				default:
					fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
				}
			} else {
				usageSet()
			}
		case strings.HasPrefix(line, "/show"):
			args := strings.Fields(line)
			if len(args) > 1 {
				client, err := api.ClientFromEnvironment()
				if err != nil {
					fmt.Println("error: couldn't connect to ollama server")
					return err
				}
386
				req := &api.ShowRequest{
Michael Yang's avatar
Michael Yang committed
387
388
389
					Name:    opts.Model,
					System:  opts.System,
					Options: opts.Options,
390
391
				}
				resp, err := client.Show(cmd.Context(), req)
392
393
394
395
396
397
				if err != nil {
					fmt.Println("error: couldn't get model")
					return err
				}

				switch args[1] {
398
				case "info":
399
					_ = showInfo(resp, false, os.Stderr)
400
401
				case "license":
					if resp.License == "" {
402
						fmt.Println("No license was specified for this model.")
403
404
405
406
407
408
					} else {
						fmt.Println(resp.License)
					}
				case "modelfile":
					fmt.Println(resp.Modelfile)
				case "parameters":
Patrick Devine's avatar
Patrick Devine committed
409
					fmt.Println("Model defined parameters:")
410
					if resp.Parameters == "" {
Patrick Devine's avatar
Patrick Devine committed
411
						fmt.Println("  No additional parameters were specified for this model.")
412
					} else {
Patrick Devine's avatar
Patrick Devine committed
413
414
						for _, l := range strings.Split(resp.Parameters, "\n") {
							fmt.Printf("  %s\n", l)
415
						}
Patrick Devine's avatar
Patrick Devine committed
416
417
418
419
420
421
422
423
					}
					fmt.Println()
					if len(opts.Options) > 0 {
						fmt.Println("User defined parameters:")
						for k, v := range opts.Options {
							fmt.Printf("  %-*s %v\n", 30, k, v)
						}
						fmt.Println()
424
425
426
427
428
429
430
431
					}
				case "system":
					switch {
					case opts.System != "":
						fmt.Println(opts.System + "\n")
					case resp.System != "":
						fmt.Println(resp.System + "\n")
					default:
432
						fmt.Println("No system message was specified for this model.")
433
434
					}
				case "template":
Patrick Devine's avatar
Patrick Devine committed
435
					if resp.Template != "" {
436
						fmt.Println(resp.Template)
Patrick Devine's avatar
Patrick Devine committed
437
					} else {
438
						fmt.Println("No prompt template was specified for this model.")
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
					}
				default:
					fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
				}
			} else {
				usageShow()
			}
		case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
			args := strings.Fields(line)
			if len(args) > 1 {
				switch args[1] {
				case "set", "/set":
					usageSet()
				case "show", "/show":
					usageShow()
				case "shortcut", "shortcuts":
					usageShortcuts()
				}
			} else {
				usage()
			}
460
		case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
461
462
463
464
465
			return nil
		case strings.HasPrefix(line, "/"):
			args := strings.Fields(line)
			isFile := false

466
			if opts.MultiModal {
467
468
469
470
471
472
473
474
				for _, f := range extractFileNames(line) {
					if strings.HasPrefix(f, args[0]) {
						isFile = true
						break
					}
				}
			}

475
			if !isFile {
476
477
478
				fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
				continue
			}
479
480

			sb.WriteString(line)
481
		default:
482
			sb.WriteString(line)
483
484
		}

485
		if sb.Len() > 0 && multiline == MultilineNone {
486
487
			newMessage := api.Message{Role: "user", Content: sb.String()}

488
			if opts.MultiModal {
489
				msg, images, err := extractFileData(sb.String())
490
491
492
				if err != nil {
					return err
				}
493

494
				newMessage.Content = msg
495
				newMessage.Images = images
496
			}
497

498
499
500
501
			opts.Messages = append(opts.Messages, newMessage)

			assistant, err := chat(cmd, opts)
			if err != nil {
Michael Yang's avatar
Michael Yang committed
502
503
				if strings.Contains(err.Error(), "does not support thinking") ||
					strings.Contains(err.Error(), "invalid think value") {
504
505
506
507
					fmt.Printf("error: %v\n", err)
					sb.Reset()
					continue
				}
508
509
				return err
			}
510
511
512
			if assistant != nil {
				opts.Messages = append(opts.Messages, *assistant)
			}
513

514
			sb.Reset()
515
516
517
518
		}
	}
}

519
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
520
521
522
523
524
525
526
	parentModel := opts.ParentModel

	modelName := model.ParseName(parentModel)
	if !modelName.IsValid() {
		parentModel = ""
	}

527
	req := &api.CreateRequest{
528
529
		Model: name,
		From:  cmp.Or(parentModel, opts.Model),
530
	}
Michael Yang's avatar
Michael Yang committed
531

532
	if opts.System != "" {
533
		req.System = opts.System
534
535
	}

536
537
	if len(opts.Options) > 0 {
		req.Parameters = opts.Options
538
539
	}

540
541
	if len(opts.Messages) > 0 {
		req.Messages = opts.Messages
542
543
	}

544
	return req
545
546
}

// filePathUnescaper undoes shell-style backslash escapes for the characters
// listed below (e.g. `\ ` -> ` `, `\\` -> `\`). A strings.Replacer is safe for
// concurrent use, so it is built once instead of on every call.
var filePathUnescaper = strings.NewReplacer(
	"\\ ", " ", // Escaped space
	"\\(", "(", // Escaped left parenthesis
	"\\)", ")", // Escaped right parenthesis
	"\\[", "[", // Escaped left square bracket
	"\\]", "]", // Escaped right square bracket
	"\\{", "{", // Escaped left curly brace
	"\\}", "}", // Escaped right curly brace
	"\\$", "$", // Escaped dollar sign
	"\\&", "&", // Escaped ampersand
	"\\;", ";", // Escaped semicolon
	"\\'", "'", // Escaped single quote
	"\\\\", "\\", // Escaped backslash
	"\\*", "*", // Escaped asterisk
	"\\?", "?", // Escaped question mark
	"\\~", "~", // Escaped tilde
)

// normalizeFilePath removes backslash escapes from fp, turning for example
// `/tmp/my\ image.png` into `/tmp/my image.png`. Other characters, including
// backslashes not followed by one of the escaped characters, are left as-is.
func normalizeFilePath(fp string) string {
	return filePathUnescaper.Replace(fp)
}

// fileNamePattern matches candidate image paths: an optional drive letter
// (e.g. "C:"), then "./", "/" or "\", then the shortest run of non-space
// characters, backslashes or spaces that ends in .jpg, .jpeg, .png or .webp
// (extension matched case-insensitively) at a word boundary.
// It can also match strings that are not real files; callers are expected to
// check for existence. Compiled once because *regexp.Regexp is safe for
// concurrent use.
var fileNamePattern = regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`)

// extractFileNames returns every substring of input that looks like an image
// file path, in the order they appear. It returns nil when there are none.
func extractFileNames(input string) []string {
	return fileNamePattern.FindAllString(input, -1)
}

577
func extractFileData(input string) (string, []api.ImageData, error) {
578
	filePaths := extractFileNames(input)
579
	var imgs []api.ImageData
580
581
582
583

	for _, fp := range filePaths {
		nfp := normalizeFilePath(fp)
		data, err := getImageData(nfp)
584
585
586
		if errors.Is(err, os.ErrNotExist) {
			continue
		} else if err != nil {
587
			fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
588
589
			return "", imgs, err
		}
590
		fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
591
592
		input = strings.ReplaceAll(input, "'"+nfp+"'", "")
		input = strings.ReplaceAll(input, "'"+fp+"'", "")
593
594
595
		input = strings.ReplaceAll(input, fp, "")
		imgs = append(imgs, data)
	}
596
	return strings.TrimSpace(input), imgs, nil
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
}

// getImageData reads the file at filePath and returns its contents if it is a
// supported image.
//
// The content type is sniffed with http.DetectContentType from the first (up
// to) 512 bytes and must be JPEG, PNG or WebP; files larger than 100MB are
// rejected. Errors from opening the file are returned unwrapped, so callers
// can test for os.ErrNotExist with errors.Is.
func getImageData(filePath string) ([]byte, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	// io.ReadFull (rather than a single Read) guarantees a full header when
	// the file is large enough; for shorter files only the bytes actually read
	// are sniffed, instead of a zero-padded buffer.
	header := make([]byte, 512)
	n, err := io.ReadFull(file, header)
	if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
		return nil, err
	}

	contentType := http.DetectContentType(header[:n])
	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
	if !slices.Contains(allowedTypes, contentType) {
		return nil, fmt.Errorf("invalid image type: %s", contentType)
	}

	info, err := file.Stat()
	if err != nil {
		return nil, err
	}

	// Check if the file size exceeds 100MB
	const maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
	if info.Size() > maxSize {
		return nil, errors.New("file size exceeds maximum limit (100MB)")
	}

	// Rewind past the sniffed header and read the whole file.
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	buf := make([]byte, info.Size())
	if _, err := io.ReadFull(file, buf); err != nil {
		return nil, err
	}

	return buf, nil
}