Unverified Commit 77ccbf04 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

Merge pull request #6128 from ollama/mxyng/lint

enable gofmt/gofumpt/goimports/tenv
parents 4addf6b5 b732beba
...@@ -6,10 +6,11 @@ import ( ...@@ -6,10 +6,11 @@ import (
"os" "os"
"testing" "testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
) )
func TestEstimateGPULayers(t *testing.T) { func TestEstimateGPULayers(t *testing.T) {
......
...@@ -184,15 +184,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr ...@@ -184,15 +184,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params := []string{ params := []string{
"--model", model, "--model", model,
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx), "--ctx-size", strconv.Itoa(opts.NumCtx),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch), "--batch-size", strconv.Itoa(opts.NumBatch),
"--embedding", "--embedding",
} }
params = append(params, "--log-disable") params = append(params, "--log-disable")
if opts.NumGPU >= 0 { if opts.NumGPU >= 0 {
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
} }
if envconfig.Debug() { if envconfig.Debug() {
...@@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr ...@@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
if opts.MainGPU > 0 { if opts.MainGPU > 0 {
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU)) params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
} }
if len(adapters) > 0 { if len(adapters) > 0 {
...@@ -214,7 +214,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr ...@@ -214,7 +214,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
if opts.NumThread > 0 { if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread)) params = append(params, "--threads", strconv.Itoa(opts.NumThread))
} }
if !opts.F16KV { if !opts.F16KV {
...@@ -260,7 +260,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr ...@@ -260,7 +260,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa") params = append(params, "--numa")
} }
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) params = append(params, "--parallel", strconv.Itoa(numParallel))
if estimate.TensorSplit != "" { if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit) params = append(params, "--tensor-split", estimate.TensorSplit)
...@@ -425,7 +425,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr ...@@ -425,7 +425,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
if strings.Contains(s.status.LastErrMsg, "unknown model") { if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade" s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
} }
s.done <- fmt.Errorf(s.status.LastErrMsg) s.done <- errors.New(s.status.LastErrMsg)
} else { } else {
s.done <- err s.done <- err
} }
......
...@@ -3,8 +3,9 @@ package main ...@@ -3,8 +3,9 @@ package main
import ( import (
"context" "context"
"github.com/ollama/ollama/cmd"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/ollama/ollama/cmd"
) )
func main() { func main() {
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
...@@ -14,6 +15,7 @@ import ( ...@@ -14,6 +15,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
...@@ -367,24 +369,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { ...@@ -367,24 +369,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
for _, c := range content { for _, c := range content {
data, ok := c.(map[string]any) data, ok := c.(map[string]any)
if !ok { if !ok {
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
switch data["type"] { switch data["type"] {
case "text": case "text":
text, ok := data["text"].(string) text, ok := data["text"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
messages = append(messages, api.Message{Role: msg.Role, Content: text}) messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url": case "image_url":
var url string var url string
if urlMap, ok := data["image_url"].(map[string]any); ok { if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok { if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
} else { } else {
if url, ok = data["image_url"].(string); !ok { if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
} }
...@@ -400,17 +402,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { ...@@ -400,17 +402,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
if !valid { if !valid {
return nil, fmt.Errorf("invalid image input") return nil, errors.New("invalid image input")
} }
img, err := base64.StdEncoding.DecodeString(url) img, err := base64.StdEncoding.DecodeString(url)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default: default:
return nil, fmt.Errorf("invalid message format") return nil, errors.New("invalid message format")
} }
} }
default: default:
...@@ -423,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { ...@@ -423,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
toolCalls[i].Function.Name = tc.Function.Name toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments) err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid tool call arguments") return nil, errors.New("invalid tool call arguments")
} }
} }
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls}) messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
...@@ -737,14 +739,12 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) { ...@@ -737,14 +739,12 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
func (w *EmbedWriter) writeResponse(data []byte) (int, error) { func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse) err := json.Unmarshal(data, &embedResponse)
if err != nil { if err != nil {
return 0, err return 0, err
} }
w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil { if err != nil {
return 0, err return 0, err
} }
......
...@@ -12,13 +12,16 @@ import ( ...@@ -12,13 +12,16 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
) )
const prefix = `data:image/jpeg;base64,` const (
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` prefix = `data:image/jpeg;base64,`
const imageURL = prefix + image image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
)
func prepareRequest(req *http.Request, body any) { func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body) bodyBytes, _ := json.Marshal(body)
......
...@@ -82,7 +82,7 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|> ...@@ -82,7 +82,7 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
} }
func TestParseFileFrom(t *testing.T) { func TestParseFileFrom(t *testing.T) {
var cases = []struct { cases := []struct {
input string input string
expected []Command expected []Command
err error err error
...@@ -185,7 +185,7 @@ BADCOMMAND param1 value1 ...@@ -185,7 +185,7 @@ BADCOMMAND param1 value1
} }
func TestParseFileMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
var cases = []struct { cases := []struct {
input string input string
expected []Command expected []Command
err error err error
...@@ -276,7 +276,7 @@ MESSAGE system`, ...@@ -276,7 +276,7 @@ MESSAGE system`,
} }
func TestParseFileQuoted(t *testing.T) { func TestParseFileQuoted(t *testing.T) {
var cases = []struct { cases := []struct {
multiline string multiline string
expected []Command expected []Command
err error err error
...@@ -430,7 +430,7 @@ TEMPLATE """ ...@@ -430,7 +430,7 @@ TEMPLATE """
} }
func TestParseFileParameters(t *testing.T) { func TestParseFileParameters(t *testing.T) {
var cases = map[string]struct { cases := map[string]struct {
name, value string name, value string
}{ }{
"numa true": {"numa", "true"}, "numa true": {"numa", "true"},
...@@ -491,7 +491,7 @@ func TestParseFileParameters(t *testing.T) { ...@@ -491,7 +491,7 @@ func TestParseFileParameters(t *testing.T) {
} }
func TestParseFileComments(t *testing.T) { func TestParseFileComments(t *testing.T) {
var cases = []struct { cases := []struct {
input string input string
expected []Command expected []Command
}{ }{
...@@ -516,7 +516,7 @@ FROM foo ...@@ -516,7 +516,7 @@ FROM foo
} }
func TestParseFileFormatParseFile(t *testing.T) { func TestParseFileFormatParseFile(t *testing.T) {
var cases = []string{ cases := []string{
` `
FROM foo FROM foo
ADAPTER adapter1 ADAPTER adapter1
......
...@@ -6,8 +6,9 @@ import ( ...@@ -6,8 +6,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/ollama/ollama/format"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/format"
) )
type Bar struct { type Bar struct {
......
...@@ -13,7 +13,7 @@ type Buffer struct { ...@@ -13,7 +13,7 @@ type Buffer struct {
DisplayPos int DisplayPos int
Pos int Pos int
Buf *arraylist.List Buf *arraylist.List
//LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end // LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
LineHasSpace *arraylist.List LineHasSpace *arraylist.List
Prompt *Prompt Prompt *Prompt
LineWidth int LineWidth int
...@@ -56,7 +56,7 @@ func (b *Buffer) GetLineSpacing(line int) bool { ...@@ -56,7 +56,7 @@ func (b *Buffer) GetLineSpacing(line int) bool {
func (b *Buffer) MoveLeft() { func (b *Buffer) MoveLeft() {
if b.Pos > 0 { if b.Pos > 0 {
//asserts that we retrieve a rune // asserts that we retrieve a rune
if e, ok := b.Buf.Get(b.Pos - 1); ok { if e, ok := b.Buf.Get(b.Pos - 1); ok {
if r, ok := e.(rune); ok { if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r) rLength := runewidth.RuneWidth(r)
......
...@@ -4,9 +4,7 @@ import ( ...@@ -4,9 +4,7 @@ import (
"errors" "errors"
) )
var ( var ErrInterrupt = errors.New("Interrupt")
ErrInterrupt = errors.New("Interrupt")
)
type InterruptError struct { type InterruptError struct {
Line []rune Line []rune
......
...@@ -7,8 +7,10 @@ import ( ...@@ -7,8 +7,10 @@ import (
"unsafe" "unsafe"
) )
const tcgets = 0x5401 const (
const tcsets = 0x5402 tcgets = 0x5401
tcsets = 0x5402
)
func getTermios(fd uintptr) (*Termios, error) { func getTermios(fd uintptr) (*Termios, error) {
termios := new(Termios) termios := new(Termios)
......
...@@ -28,8 +28,10 @@ import ( ...@@ -28,8 +28,10 @@ import (
const maxRetries = 6 const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded") var (
var errPartStalled = errors.New("part stalled") errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
)
var blobDownloadManager sync.Map var blobDownloadManager sync.Map
......
...@@ -828,7 +828,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu ...@@ -828,7 +828,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure { if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http") return errors.New("insecure protocol http")
} }
manifest, _, err := GetManifest(mp) manifest, _, err := GetManifest(mp)
...@@ -895,7 +895,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu ...@@ -895,7 +895,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
} }
if mp.ProtocolScheme == "http" && !regOpts.Insecure { if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http") return errors.New("insecure protocol http")
} }
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})
...@@ -1010,7 +1010,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) { ...@@ -1010,7 +1010,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
} }
var errUnauthorized = fmt.Errorf("unauthorized: access denied") var errUnauthorized = errors.New("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token // getTokenSubject returns the subject of a JWT token, it does not validate the token
func getTokenSubject(token string) string { func getTokenSubject(token string) string {
......
...@@ -2,9 +2,9 @@ package server ...@@ -2,9 +2,9 @@ package server
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"log/slog" "log/slog"
"os" "os"
...@@ -88,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { ...@@ -88,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
m.filepath = p m.filepath = p
m.fi = fi m.fi = fi
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil)) m.digest = hex.EncodeToString(sha256sum.Sum(nil))
return &m, nil return &m, nil
} }
......
...@@ -14,7 +14,7 @@ func createManifest(t *testing.T, path, name string) { ...@@ -14,7 +14,7 @@ func createManifest(t *testing.T, path, name string) {
t.Helper() t.Helper()
p := filepath.Join(path, "manifests", name) p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil { if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
......
...@@ -55,8 +55,10 @@ func init() { ...@@ -55,8 +55,10 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var errRequired = errors.New("is required") var (
var errBadTemplate = errors.New("template error") errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
)
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
...@@ -369,7 +371,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { ...@@ -369,7 +371,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input[i] = s input[i] = s
} }
embeddings, err := r.Embed(c.Request.Context(), input) embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil { if err != nil {
slog.Error("embedding generation failed", "error", err) slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
...@@ -430,7 +431,6 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { ...@@ -430,7 +431,6 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
} }
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
...@@ -556,7 +556,7 @@ func checkNameExists(name model.Name) error { ...@@ -556,7 +556,7 @@ func checkNameExists(name model.Name) error {
for n := range names { for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name { if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists") return errors.New("a model with that name already exists")
} }
} }
...@@ -729,7 +729,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -729,7 +729,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
n := model.ParseName(req.Model) n := model.ParseName(req.Model)
if !n.IsValid() { if !n.IsValid() {
return nil, fmt.Errorf("invalid model name") return nil, errors.New("invalid model name")
} }
manifest, err := ParseNamedManifest(n) manifest, err := ParseNamedManifest(n)
...@@ -993,7 +993,7 @@ func allowedHost(host string) bool { ...@@ -993,7 +993,7 @@ func allowedHost(host string) bool {
return true return true
} }
var tlds = []string{ tlds := []string{
"localhost", "localhost",
"local", "local",
"internal", "internal",
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
...@@ -489,7 +490,7 @@ func TestCreateTemplateSystem(t *testing.T) { ...@@ -489,7 +490,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }
...@@ -501,7 +502,7 @@ func TestCreateTemplateSystem(t *testing.T) { ...@@ -501,7 +502,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }
...@@ -513,7 +514,7 @@ func TestCreateTemplateSystem(t *testing.T) { ...@@ -513,7 +514,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment