Unverified Commit 77ccbf04 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

Merge pull request #6128 from ollama/mxyng/lint

enable gofmt/gofumpt/goimports/tenv
parents 4addf6b5 b732beba
......@@ -6,10 +6,11 @@ import (
"os"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
)
func TestEstimateGPULayers(t *testing.T) {
......
......@@ -184,15 +184,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params := []string{
"--model", model,
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--ctx-size", strconv.Itoa(opts.NumCtx),
"--batch-size", strconv.Itoa(opts.NumBatch),
"--embedding",
}
params = append(params, "--log-disable")
if opts.NumGPU >= 0 {
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
}
if envconfig.Debug() {
......@@ -200,7 +200,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
if opts.MainGPU > 0 {
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
}
if len(adapters) > 0 {
......@@ -214,7 +214,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
}
if !opts.F16KV {
......@@ -260,7 +260,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
params = append(params, "--parallel", strconv.Itoa(numParallel))
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
......@@ -425,7 +425,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
}
s.done <- fmt.Errorf(s.status.LastErrMsg)
s.done <- errors.New(s.status.LastErrMsg)
} else {
s.done <- err
}
......
......@@ -3,8 +3,9 @@ package main
import (
"context"
"github.com/ollama/ollama/cmd"
"github.com/spf13/cobra"
"github.com/ollama/ollama/cmd"
)
func main() {
......
......@@ -5,6 +5,7 @@ import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
......@@ -14,6 +15,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
......@@ -367,24 +369,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
}
......@@ -400,17 +402,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
if !valid {
return nil, fmt.Errorf("invalid image input")
return nil, errors.New("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default:
return nil, fmt.Errorf("invalid message format")
return nil, errors.New("invalid message format")
}
}
default:
......@@ -423,7 +425,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil {
return nil, fmt.Errorf("invalid tool call arguments")
return nil, errors.New("invalid tool call arguments")
}
}
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
......@@ -737,14 +739,12 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
......
......@@ -12,13 +12,16 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
)
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
)
func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
......
......@@ -82,7 +82,7 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
}
func TestParseFileFrom(t *testing.T) {
var cases = []struct {
cases := []struct {
input string
expected []Command
err error
......@@ -185,7 +185,7 @@ BADCOMMAND param1 value1
}
func TestParseFileMessages(t *testing.T) {
var cases = []struct {
cases := []struct {
input string
expected []Command
err error
......@@ -276,7 +276,7 @@ MESSAGE system`,
}
func TestParseFileQuoted(t *testing.T) {
var cases = []struct {
cases := []struct {
multiline string
expected []Command
err error
......@@ -430,7 +430,7 @@ TEMPLATE """
}
func TestParseFileParameters(t *testing.T) {
var cases = map[string]struct {
cases := map[string]struct {
name, value string
}{
"numa true": {"numa", "true"},
......@@ -491,7 +491,7 @@ func TestParseFileParameters(t *testing.T) {
}
func TestParseFileComments(t *testing.T) {
var cases = []struct {
cases := []struct {
input string
expected []Command
}{
......@@ -516,7 +516,7 @@ FROM foo
}
func TestParseFileFormatParseFile(t *testing.T) {
var cases = []string{
cases := []string{
`
FROM foo
ADAPTER adapter1
......
......@@ -6,8 +6,9 @@ import (
"strings"
"time"
"github.com/ollama/ollama/format"
"golang.org/x/term"
"github.com/ollama/ollama/format"
)
type Bar struct {
......
......@@ -13,7 +13,7 @@ type Buffer struct {
DisplayPos int
Pos int
Buf *arraylist.List
//LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
// LineHasSpace is an arraylist of bools to keep track of whether a line has a space at the end
LineHasSpace *arraylist.List
Prompt *Prompt
LineWidth int
......@@ -56,7 +56,7 @@ func (b *Buffer) GetLineSpacing(line int) bool {
func (b *Buffer) MoveLeft() {
if b.Pos > 0 {
//asserts that we retrieve a rune
// asserts that we retrieve a rune
if e, ok := b.Buf.Get(b.Pos - 1); ok {
if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r)
......
......@@ -4,9 +4,7 @@ import (
"errors"
)
var (
ErrInterrupt = errors.New("Interrupt")
)
var ErrInterrupt = errors.New("Interrupt")
type InterruptError struct {
Line []rune
......
......@@ -7,8 +7,10 @@ import (
"unsafe"
)
const tcgets = 0x5401
const tcsets = 0x5402
const (
tcgets = 0x5401
tcsets = 0x5402
)
func getTermios(fd uintptr) (*Termios, error) {
termios := new(Termios)
......
......@@ -28,8 +28,10 @@ import (
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
)
var blobDownloadManager sync.Map
......
......@@ -828,7 +828,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
return errors.New("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
......@@ -895,7 +895,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
return errors.New("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
......@@ -1010,7 +1010,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
var errUnauthorized = errors.New("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token
func getTokenSubject(token string) string {
......
......@@ -2,9 +2,9 @@ package server
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
......@@ -88,7 +88,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
m.filepath = p
m.fi = fi
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
return &m, nil
}
......
......@@ -14,7 +14,7 @@ func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
t.Fatal(err)
}
......
......@@ -9,6 +9,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
......
......@@ -6,6 +6,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
......
......@@ -55,8 +55,10 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
var errBadTemplate = errors.New("template error")
var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
)
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
......@@ -369,7 +371,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
......@@ -430,7 +431,6 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
......@@ -556,7 +556,7 @@ func checkNameExists(name model.Name) error {
for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists")
return errors.New("a model with that name already exists")
}
}
......@@ -729,7 +729,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
n := model.ParseName(req.Model)
if !n.IsValid() {
return nil, fmt.Errorf("invalid model name")
return nil, errors.New("invalid model name")
}
manifest, err := ParseNamedManifest(n)
......@@ -993,7 +993,7 @@ func allowedHost(host string) bool {
return true
}
var tlds = []string{
tlds := []string{
"localhost",
"local",
"internal",
......
......@@ -13,6 +13,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
......@@ -489,7 +490,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
......@@ -501,7 +502,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
......@@ -513,7 +514,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
......
......@@ -9,6 +9,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
......
......@@ -8,6 +8,7 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment