Commit 920a4b07 authored by Daniel Hiltgen

Merge remote-tracking branch 'upstream/main' into pr3702

parents c496967e ee49844d
package parser

import (
    "bufio"
    "bytes"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "slices"
)

type Command struct {
    Name string
    Args string
}

func (c *Command) Reset() {
    c.Name = ""
    c.Args = ""
}

func Parse(reader io.Reader) ([]Command, error) {
    var commands []Command
    var command, modelCommand Command

    scanner := bufio.NewScanner(reader)
    scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
    scanner.Split(scanModelfile)
    for scanner.Scan() {
        line := scanner.Bytes()

        fields := bytes.SplitN(line, []byte(" "), 2)
        if len(fields) == 0 || len(fields[0]) == 0 {
            continue
        }

        switch string(bytes.ToUpper(fields[0])) {
        case "FROM":
            command.Name = "model"
            command.Args = string(bytes.TrimSpace(fields[1]))
            // copy command for validation
            modelCommand = command
        case "ADAPTER":
            command.Name = string(bytes.ToLower(fields[0]))
            command.Args = string(bytes.TrimSpace(fields[1]))
        case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
            command.Name = string(bytes.ToLower(fields[0]))
            command.Args = string(fields[1])
        case "PARAMETER":
            fields = bytes.SplitN(fields[1], []byte(" "), 2)
            if len(fields) < 2 {
                return nil, fmt.Errorf("missing value for %s", fields)
            }

            command.Name = string(fields[0])
            command.Args = string(bytes.TrimSpace(fields[1]))
        case "EMBED":
            return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
        case "MESSAGE":
            command.Name = string(bytes.ToLower(fields[0]))
            fields = bytes.SplitN(fields[1], []byte(" "), 2)
            if len(fields) < 2 {
                return nil, fmt.Errorf("should be in the format <role> <message>")
            }

            if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
                return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
            }

            command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
        default:
            if !bytes.HasPrefix(fields[0], []byte("#")) {
                // log a warning for unknown commands
                slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
            }

            continue
        }

        commands = append(commands, command)
        command.Reset()
    }

    if modelCommand.Args == "" {
        return nil, errors.New("no FROM line for the model was specified")
    }

    return commands, scanner.Err()
}
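// scanModelfile is the bufio.SplitFunc used above: it first tries to read a
// triple-quoted ("""...""") block, then a single-quoted ("...") block, and
// otherwise falls back to plain line scanning.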
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
    advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
    if err != nil {
        return 0, nil, err
    }

    if advance > 0 && token != nil {
        return advance, token, nil
    }

    advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
    if err != nil {
        return 0, nil, err
    }

    if advance > 0 && token != nil {
        return advance, token, nil
    }

    return bufio.ScanLines(data, atEOF)
}
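// scan returns a token delimited by openBytes/closeBytes when the opening
// delimiter appears before the first newline. The token keeps any prefix
// before the delimiter but drops the delimiters themselves; (0, nil, nil)
// means no match yet (or more data is needed), and an unterminated block at
// EOF is an error.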
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
    newline := bytes.IndexByte(data, '\n')

    if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
        end := bytes.Index(data[start+len(openBytes):], closeBytes)
        if end < 0 {
            if atEOF {
                return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
            } else {
                return 0, nil, nil
            }
        }

        n := start + len(openBytes) + end + len(closeBytes)

        newData := data[:start]
        newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
        return n, newData, nil
    }

    return 0, nil, nil
}
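To see the split function and parser working together, here is a minimal sketch (assuming this package's import path, github.com/ollama/ollama/parser) that parses a Modelfile with a triple-quoted template:

package main

import (
    "fmt"
    "strings"

    "github.com/ollama/ollama/parser"
)

func main() {
    // The TEMPLATE value is wrapped in """ so scanModelfile emits it as a single token.
    modelfile := "FROM llama2\n" +
        "TEMPLATE \"\"\"{{ .System }} {{ .Prompt }}\"\"\"\n" +
        "PARAMETER temperature 0.7\n"

    commands, err := parser.Parse(strings.NewReader(modelfile))
    if err != nil {
        panic(err)
    }

    for _, c := range commands {
        fmt.Printf("%s => %s\n", c.Name, c.Args)
    }
    // Output:
    // model => llama2
    // template => {{ .System }} {{ .Prompt }}
    // temperature => 0.7
}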
package parser

import (
    "strings"
    "testing"

    "github.com/stretchr/testify/assert"
)

func Test_Parser(t *testing.T) {
    input := `
FROM model1
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
`

    reader := strings.NewReader(input)

    commands, err := Parse(reader)
    assert.Nil(t, err)

    expectedCommands := []Command{
        {Name: "model", Args: "model1"},
        {Name: "adapter", Args: "adapter1"},
        {Name: "license", Args: "MIT"},
        {Name: "param1", Args: "value1"},
        {Name: "param2", Args: "value2"},
        {Name: "template", Args: "template1"},
    }

    assert.Equal(t, expectedCommands, commands)
}

func Test_Parser_NoFromLine(t *testing.T) {
    input := `
PARAMETER param1 value1
PARAMETER param2 value2
`

    reader := strings.NewReader(input)

    _, err := Parse(reader)
    assert.ErrorContains(t, err, "no FROM line")
}

func Test_Parser_MissingValue(t *testing.T) {
    input := `
FROM foo
PARAMETER param1
`

    reader := strings.NewReader(input)

    _, err := Parse(reader)
    assert.ErrorContains(t, err, "missing value for [param1]")
}

func Test_Parser_Messages(t *testing.T) {
    input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`

    reader := strings.NewReader(input)
    commands, err := Parse(reader)
    assert.Nil(t, err)

    expectedCommands := []Command{
        {Name: "model", Args: "foo"},
        {Name: "message", Args: "system: You are a Parser. Always Parse things."},
        {Name: "message", Args: "user: Hey there!"},
        {Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
    }

    assert.Equal(t, expectedCommands, commands)
}

func Test_Parser_Messages_BadRole(t *testing.T) {
    input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`

    reader := strings.NewReader(input)
    _, err := Parse(reader)
    assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}
@@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ:
fd := int(syscall.Stdin)
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter:
case CharEnter, CharCtrlJ:
output := buf.String()
if output != "" {
i.History.Add([]rune(output))
@@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
metaDel = false
continue
}
if r >= CharSpace || r == CharEnter {
if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
buf.Add(r)
}
}
......
@@ -7,6 +7,8 @@
$ErrorActionPreference = "Stop"
function checkEnv() {
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
Write-host "Building for ${script:TARGET_ARCH}"
write-host "Locating required tools and paths"
$script:SRC_DIR=$PWD
if (!$env:VCToolsRedistDir) {
@@ -30,7 +32,7 @@ function checkEnv() {
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
$script:DEPS_DIR="${script:SRC_DIR}\dist\windeps"
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
echo "Checking version"
if (!$env:VERSION) {
@@ -81,8 +83,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
New-Item -ItemType Directory -Path .\dist -Force
cp .\ollama.exe .\dist\ollama-windows-amd64.exe
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
}
function buildApp() {
@@ -101,7 +103,6 @@ function buildApp() {
function gatherDependencies() {
write-host "Gathering runtime dependencies"
cd "${script:SRC_DIR}"
rm -ea 0 -recurse -force -path "${script:DEPS_DIR}"
md "${script:DEPS_DIR}" -ea 0 > $null
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
@@ -110,9 +111,6 @@ function gatherDependencies() {
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cudart64_*.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cublas64_*.dll" "${script:DEPS_DIR}\"
cp "${script:NVIDIA_DIR}\cublasLt64_*.dll" "${script:DEPS_DIR}\"
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
if ("${env:KEY_CONTAINER}") {
@@ -124,7 +122,6 @@ function gatherDependencies() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildInstaller() {
@@ -132,19 +129,25 @@ function buildInstaller() {
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
} else {
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
}
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
function distZip() {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
}
try {
checkEnv
buildOllama
buildApp
gatherDependencies
buildInstaller
distZip
} catch {
write-host "Build Failed"
write-host $_
......
@@ -166,8 +166,8 @@ fi
if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
# Look for pre-existing ROCm v6 before downloading the dependencies
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm"; do
if [ -n "${search}" ] && [ -e "${search}/lib/libhipblas.so.2" ]; then
for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
status "Compatible AMD GPU ROCm library detected at ${search}"
install_success
exit 0
......
package envconfig

import (
    "fmt"
    "log/slog"
    "os"
    "path/filepath"
    "runtime"
    "strconv"
    "strings"
)

var (
    // Set via OLLAMA_ORIGINS in the environment
    AllowOrigins []string
    // Set via OLLAMA_DEBUG in the environment
    Debug bool
    // Set via OLLAMA_LLM_LIBRARY in the environment
    LLMLibrary string
    // Set via OLLAMA_MAX_LOADED_MODELS in the environment
    MaxRunners int
    // Set via OLLAMA_MAX_QUEUE in the environment
    MaxQueuedRequests int
    // Set via OLLAMA_MAX_VRAM in the environment
    MaxVRAM uint64
    // Set via OLLAMA_NOPRUNE in the environment
    NoPrune bool
    // Set via OLLAMA_NUM_PARALLEL in the environment
    NumParallel int
    // Set via OLLAMA_RUNNERS_DIR in the environment
    RunnersDir string
    // Set via OLLAMA_TMPDIR in the environment
    TmpDir string
)

func AsMap() map[string]string {
    return map[string]string{
        "OLLAMA_ORIGINS":           fmt.Sprintf("%v", AllowOrigins),
        "OLLAMA_DEBUG":             fmt.Sprintf("%v", Debug),
        "OLLAMA_LLM_LIBRARY":       fmt.Sprintf("%v", LLMLibrary),
        "OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
        "OLLAMA_MAX_QUEUE":         fmt.Sprintf("%v", MaxQueuedRequests),
        "OLLAMA_MAX_VRAM":          fmt.Sprintf("%v", MaxVRAM),
        "OLLAMA_NOPRUNE":           fmt.Sprintf("%v", NoPrune),
        "OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
        "OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
        "OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
    }
}

var defaultAllowOrigins = []string{
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
}

// Clean quotes and spaces from the value
func clean(key string) string {
    return strings.Trim(os.Getenv(key), "\"' ")
}

func init() {
    // default values
    NumParallel = 1
    MaxRunners = 1
    MaxQueuedRequests = 512

    LoadConfig()
}

func LoadConfig() {
    if debug := clean("OLLAMA_DEBUG"); debug != "" {
        d, err := strconv.ParseBool(debug)
        if err == nil {
            Debug = d
        } else {
            Debug = true
        }
    }

    RunnersDir = clean("OLLAMA_RUNNERS_DIR")
    if runtime.GOOS == "windows" && RunnersDir == "" {
        // On Windows we do not carry the payloads inside the main executable
        appExe, err := os.Executable()
        if err != nil {
            slog.Error("failed to lookup executable path", "error", err)
        }

        cwd, err := os.Getwd()
        if err != nil {
            slog.Error("failed to lookup working directory", "error", err)
        }

        var paths []string
        for _, root := range []string{filepath.Dir(appExe), cwd} {
            paths = append(paths,
                filepath.Join(root),
                filepath.Join(root, "windows-"+runtime.GOARCH),
                filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
            )
        }

        // Try a few variations to improve developer experience when building from source in the local tree
        for _, p := range paths {
            candidate := filepath.Join(p, "ollama_runners")
            _, err := os.Stat(candidate)
            if err == nil {
                RunnersDir = candidate
                break
            }
        }
        if RunnersDir == "" {
            slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
        }
    }

    TmpDir = clean("OLLAMA_TMPDIR")

    userLimit := clean("OLLAMA_MAX_VRAM")
    if userLimit != "" {
        avail, err := strconv.ParseUint(userLimit, 10, 64)
        if err != nil {
            slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
        } else {
            MaxVRAM = avail
        }
    }

    LLMLibrary = clean("OLLAMA_LLM_LIBRARY")

    if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
        val, err := strconv.Atoi(onp)
        if err != nil || val <= 0 {
            slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
        } else {
            NumParallel = val
        }
    }

    if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
        NoPrune = true
    }

    if origins := clean("OLLAMA_ORIGINS"); origins != "" {
        AllowOrigins = strings.Split(origins, ",")
    }
    for _, allowOrigin := range defaultAllowOrigins {
        AllowOrigins = append(AllowOrigins,
            fmt.Sprintf("http://%s", allowOrigin),
            fmt.Sprintf("https://%s", allowOrigin),
            fmt.Sprintf("http://%s:*", allowOrigin),
            fmt.Sprintf("https://%s:*", allowOrigin),
        )
    }

    maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
    if maxRunners != "" {
        m, err := strconv.Atoi(maxRunners)
        if err != nil {
            slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
        } else {
            MaxRunners = m
        }
    }

    if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
        p, err := strconv.Atoi(onp)
        if err != nil || p <= 0 {
            slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
        } else {
            MaxQueuedRequests = p
        }
    }
}
package envconfig

import (
    "os"
    "testing"

    "github.com/stretchr/testify/require"
)

func TestConfig(t *testing.T) {
    os.Setenv("OLLAMA_DEBUG", "")
    LoadConfig()
    require.False(t, Debug)
    os.Setenv("OLLAMA_DEBUG", "false")
    LoadConfig()
    require.False(t, Debug)
    os.Setenv("OLLAMA_DEBUG", "1")
    LoadConfig()
    require.True(t, Debug)
}
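A small sketch of consuming this package (the import path github.com/ollama/ollama/server/envconfig matches the imports elsewhere in this commit): set a variable, reload, and dump the snapshot.

package main

import (
    "fmt"
    "os"

    "github.com/ollama/ollama/server/envconfig"
)

func main() {
    // clean() trims quotes and spaces, so quoted values in the environment also work.
    os.Setenv("OLLAMA_NUM_PARALLEL", "4")
    os.Setenv("OLLAMA_MAX_LOADED_MODELS", "2")
    envconfig.LoadConfig()

    // AsMap returns a printable snapshot of every recognized setting.
    for k, v := range envconfig.AsMap() {
        fmt.Printf("%s=%s\n", k, v)
    }
}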
package server
import (
"archive/zip"
"bytes"
"cmp"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"net/http"
@@ -20,15 +20,16 @@ import (
"runtime"
"strconv"
"strings"
"text/template"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -51,7 +52,6 @@ type Model struct {
System string
License []string
Digest string
Size int64
Options map[string]interface{}
Messages []Message
}
@@ -60,6 +60,76 @@ func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
}
func (m *Model) String() string {
var modelfile model.File
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "model",
Args: m.ModelPath,
})
for _, adapter := range m.AdapterPaths {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "adapter",
Args: adapter,
})
}
for _, projector := range m.ProjectorPaths {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "model",
Args: projector,
})
}
if m.Template != "" {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "template",
Args: m.Template,
})
}
if m.System != "" {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "system",
Args: m.System,
})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
for _, s := range v {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: k,
Args: fmt.Sprintf("%v", s),
})
}
default:
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: k,
Args: fmt.Sprintf("%v", v),
})
}
}
for _, license := range m.License {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "license",
Args: license,
})
}
for _, msg := range m.Messages {
modelfile.Commands = append(modelfile.Commands, model.Command{
Name: "message",
Args: fmt.Sprintf("%s %s", msg.Role, msg.Content),
})
}
return modelfile.String()
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
@@ -85,50 +155,11 @@ type ConfigV2 struct {
RootFS RootFS `json:"rootfs"`
}
func (c *ConfigV2) SetModelFormat(format string) {
if c.ModelFormat == "" {
c.ModelFormat = format
}
}
func (c *ConfigV2) SetModelFamily(families ...string) {
for _, family := range families {
if c.ModelFamily == "" {
c.ModelFamily = family
}
if !slices.Contains(c.ModelFamilies, family) {
c.ModelFamilies = append(c.ModelFamilies, family)
}
}
}
func (c *ConfigV2) SetModelType(modelType string) {
if c.ModelType == "" {
c.ModelType = modelType
}
}
func (c *ConfigV2) SetFileType(fileType string) {
if c.FileType == "" {
c.FileType = fileType
}
}
type RootFS struct {
Type string `json:"type"`
DiffIDs []string `json:"diff_ids"`
}
func (m *ManifestV2) GetTotalSize() (total int64) {
for _, layer := range m.Layers {
total += layer.Size
}
total += m.Config.Size
return total
}
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
@@ -169,7 +200,6 @@ func GetModel(name string) (*Model, error) {
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Size: manifest.GetTotalSize(),
}
filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -259,7 +289,7 @@ func GetModel(name string) (*Model, error) {
return model, nil
}
func realpath(mfDir, from string) string {
func realpath(rel, from string) string {
abspath, err := filepath.Abs(from)
if err != nil {
return from
@@ -276,22 +306,15 @@ func realpath(mfDir, from string) string {
return filepath.Join(home, from[2:])
}
if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
// this is a file relative to the Modelfile
return filepath.Join(mfDir, from)
return filepath.Join(rel, from)
}
return abspath
}
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) {
deleteMap[layer.Digest] = struct{}{}
}
}
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
@@ -300,250 +323,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
},
}
var layers Layers
messages := []string{}
params := make(map[string][]string)
fromParams := make(map[string]any)
var messages []*api.Message
parameters := make(map[string]any)
for _, c := range commands {
var layers []*Layer
for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name {
case "model":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
case "model", "adapter":
var baseLayers []*layerWithGGML
if name := model.ParseName(c.Args); name.IsValid() {
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
return err
}
c.Args = blobPath
}
pathName := realpath(modelFileDir, c.Args)
ggufName, err := convertModel(name, pathName, fn)
if err != nil {
var pathErr *fs.PathError
switch {
case errors.Is(err, zip.ErrFormat):
// it's not a safetensor archive
case errors.As(err, &pathErr):
// it's not a file on disk, could be a model reference
default:
return err
}
}
if ggufName != "" {
pathName = ggufName
defer os.RemoveAll(ggufName)
if quantization != "" {
quantization = strings.ToUpper(quantization)
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
if err != nil {
return err
}
defer os.RemoveAll(tempfile.Name())
if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
return err
}
if err := tempfile.Close(); err != nil {
return err
}
pathName = tempfile.Name()
}
}
bin, err := os.Open(pathName)
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
return err
}
manifest, _, err = GetManifest(modelpath)
if err != nil {
return err
}
case err != nil:
} else if strings.HasPrefix(c.Args, "@") {
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
fn(api.ProgressResponse{Status: "reading model metadata"})
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
blob, err := os.Open(blobpath)
if err != nil {
return err
}
defer blob.Close()
fromConfigFile, err := os.Open(fromConfigPath)
baseLayers, err = parseFromFile(ctx, blob, fn)
if err != nil {
return err
}
defer fromConfigFile.Close()
} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
defer file.Close()
var fromConfig ConfigV2
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
baseLayers, err = parseFromFile(ctx, file, fn)
if err != nil {
return err
}
} else {
return fmt.Errorf("invalid model reference: %s", c.Args)
}
// if the model is still not in gguf format, error out
if fromConfig.ModelFormat != "gguf" {
return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
}
for _, baseLayer := range baseLayers {
if quantization != "" &&
baseLayer.MediaType == "application/vnd.ollama.image.model" &&
baseLayer.GGML != nil &&
baseLayer.GGML.Name() == "gguf" {
want, err := llm.ParseFileType(quantization)
if err != nil {
return err
}
config.SetModelFormat(fromConfig.ModelFormat)
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
config.SetModelType(fromConfig.ModelType)
config.SetFileType(fromConfig.FileType)
ft := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models")
} else if want != ft {
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil {
return err
}
fromParamsFile, err := os.Open(fromParamsPath)
temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
if err != nil {
return err
}
defer fromParamsFile.Close()
defer temp.Close()
defer os.Remove(temp.Name())
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
if err := llm.Quantize(blob, temp.Name(), want); err != nil {
return err
}
}
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return err
baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
if err != nil {
return err
}
}
layers.Add(layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
continue
}
defer bin.Close()
var offset int64
for {
fn(api.ProgressResponse{Status: "creating model layer"})
if _, err := bin.Seek(offset, io.SeekStart); err != nil {
return err
}
ggml, size, err := llm.DecodeGGML(bin)
if errors.Is(err, io.EOF) {
break
} else if errors.Is(err, llm.ErrUnsupportedFormat) {
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
} else if err != nil {
return err
}
config.SetModelFormat(ggml.Name())
config.SetModelFamily(ggml.KV().Architecture())
config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
config.SetFileType(ggml.KV().FileType())
mediatype := mediatype
if ggml.KV().Architecture() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
sr := io.NewSectionReader(bin, offset, size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
if baseLayer.GGML != nil {
config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
}
layers.Add(layer)
offset += size
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
blob := strings.NewReader(c.Args)
layer, err := NewLayer(blob, mediatype)
if err != nil {
return err
}
case "adapter":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
c.Args = blobPath
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
return layer.MediaType == mediatype
})
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {
return err
layers = append(layers, layer)
case "message":
role, content, ok := strings.Cut(c.Args, ": ")
if !ok {
return fmt.Errorf("invalid message: %s", c.Args)
}
defer bin.Close()
_, size, err := llm.DecodeGGML(bin)
messages = append(messages, &api.Message{Role: role, Content: content})
default:
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err != nil {
return err
}
sr := io.NewSectionReader(bin, 0, size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
for k, v := range ps {
if ks, ok := parameters[k].([]string); ok {
parameters[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
parameters[k] = vs
} else {
parameters[k] = v
}
}
}
}
layers.Add(layer)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
var err2 error
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
switch layer.MediaType {
case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones
if len(messages) > 0 {
return true
}
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
return false
case "application/vnd.ollama.image.params":
// merge inherited parameters with new ones
r, err := layer.Open()
if err != nil {
return err
err2 = err
return false
}
defer r.Close()
layers.Add(layer)
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
var ps map[string]any
if err := json.NewDecoder(r).Decode(&ps); err != nil {
err2 = err
return false
}
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
for k, v := range ps {
if _, ok := parameters[k]; !ok {
parameters[k] = v
}
}
layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
return true
default:
params[c.Name] = append(params[c.Name], c.Args)
return false
}
})
if err2 != nil {
return err2
}
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
if err := json.NewEncoder(&b).Encode(messages); err != nil {
return err
}
@@ -552,39 +506,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err
}
layers.Replace(layer)
layers = append(layers, layer)
}
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := api.FormatParams(params)
if err != nil {
return err
}
for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v
}
}
if len(parameters) > 0 {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
if err := json.NewEncoder(&b).Encode(parameters); err != nil {
return err
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return err
}
layers.Replace(layer)
layers = append(layers, layer)
}
digests := make([]string, len(layers.items))
for i, layer := range layers.items {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
}
@@ -595,36 +535,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return err
}
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
delete(deleteMap, configLayer.Digest)
for _, layer := range append(layers.items, configLayer) {
committed, err := layer.Commit()
if err != nil {
return err
for _, layer := range append(layers, layer) {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
}
}
status := "writing layer"
if !committed {
status = "using already created layer"
unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
delete(deleteMap, layer.Digest)
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, configLayer, layers.items); err != nil {
if err := WriteManifest(name, layer, layers); err != nil {
return err
}
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref, false); err != nil {
return err
}
}
@@ -633,104 +574,43 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
return nil
}
func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
r, err := zip.OpenReader(path)
if err != nil {
return "", err
}
defer r.Close()
tempDir, err := os.MkdirTemp("", "ollama-convert")
if err != nil {
return "", err
}
defer os.RemoveAll(tempDir)
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
fpath := filepath.Join(tempDir, f.Name)
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return "", err
}
rc, err := f.Open()
if err != nil {
return "", err
}
_, err = io.Copy(outFile, rc)
if err != nil {
return "", err
}
outFile.Close()
rc.Close()
}
mf, err := convert.GetModelFormat(tempDir)
if err != nil {
return "", err
}
params, err := mf.GetParams(tempDir)
if err != nil {
return "", err
}
mArch, err := mf.GetModelArch(name, tempDir, params)
if err != nil {
return "", err
}
fn(api.ProgressResponse{Status: "processing tensors"})
if err := mArch.GetTensors(); err != nil {
return "", err
func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() {
return model.Unqualified(dst)
}
if err := mArch.LoadVocab(); err != nil {
return "", err
if !src.IsFullyQualified() {
return model.Unqualified(src)
}
fn(api.ProgressResponse{Status: "converting model"})
path, err = mArch.WriteGGUF()
if err != nil {
return "", err
if src.Filepath() == dst.Filepath() {
return nil
}
return path, nil
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()
manifests, err := GetManifestPath()
if err != nil {
return err
}
destModelPath := ParseModelPath(dest)
destPath, err := destModelPath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
dstpath := filepath.Join(manifests, dst.Filepath())
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
// copy the file
input, err := os.ReadFile(srcPath)
srcpath := filepath.Join(manifests, src.Filepath())
srcfile, err := os.Open(srcpath)
if err != nil {
fmt.Println("Error reading file:", err)
return err
}
defer srcfile.Close()
err = os.WriteFile(destPath, input, 0o644)
dstfile, err := os.Create(dstpath)
if err != nil {
fmt.Println("Error reading file:", err)
return err
}
defer dstfile.Close()
return nil
_, err = io.Copy(dstfile, srcfile)
return err
}
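A sketch of driving the new CopyModel from inside the server package; the example names are hypothetical and assume model.ParseName yields fully qualified names for them (unqualified names are rejected with model.Unqualified):

package server

import (
    "log"

    "github.com/ollama/ollama/types/model"
)

func copyModelExample() {
    // hypothetical, fully qualified source and destination names
    src := model.ParseName("registry.ollama.ai/library/llama2:latest")
    dst := model.ParseName("registry.ollama.ai/library/llama2-backup:latest")

    // CopyModel duplicates only the manifest; both names then share the same blobs.
    if err := CopyModel(src, dst); err != nil {
        log.Fatal(err)
    }
}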
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
@@ -890,67 +770,6 @@ func DeleteModel(name string) error {
return nil
}
func ShowModelfile(model *Model) (string, error) {
var mt struct {
*Model
From string
Parameters map[string][]any
}
mt.Parameters = make(map[string][]any)
for k, v := range model.Options {
if s, ok := v.([]any); ok {
mt.Parameters[k] = s
continue
}
mt.Parameters[k] = []any{v}
}
mt.Model = model
mt.From = model.ModelPath
if model.ParentModel != "" {
mt.From = model.ParentModel
}
modelFile := `# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM {{ .ShortName }}
FROM {{ .From }}
TEMPLATE """{{ .Template }}"""
{{- if .System }}
SYSTEM """{{ .System }}"""
{{- end }}
{{- range $adapter := .AdapterPaths }}
ADAPTER {{ $adapter }}
{{- end }}
{{- range $k, $v := .Parameters }}
{{- range $parameter := $v }}
PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
{{- end }}
{{- end }}`
tmpl, err := template.New("").Parse(modelFile)
if err != nil {
slog.Info(fmt.Sprintf("error parsing template: %q", err))
return "", err
}
var buf bytes.Buffer
if err = tmpl.Execute(&buf, mt); err != nil {
slog.Info(fmt.Sprintf("error executing template: %q", err))
return "", err
}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -972,9 +791,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
}
return err
}
}
@@ -1011,7 +827,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
if !envconfig.NoPrune {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
@@ -1137,9 +953,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized")
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token; it does not validate the token
func getTokenSubject(token string) string {
parts := strings.Split(token, ".")
if len(parts) != 3 {
slog.Error("jwt token does not contain 3 parts")
return ""
}
payload := parts[1]
payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
return ""
}
var payloadMap map[string]interface{}
if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
return ""
}
sub, ok := payloadMap["sub"]
if !ok {
slog.Error("jwt does not contain 'sub' field")
return ""
}
return fmt.Sprintf("%s", sub)
}
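To make the payload handling concrete, a self-contained sketch that builds an unsigned token and pulls out the same "sub" claim the way getTokenSubject does:

package main

import (
    "encoding/base64"
    "encoding/json"
    "fmt"
    "strings"
)

func main() {
    // Build a fake header.payload.signature token; only the payload matters here.
    payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"anonymous"}`))
    token := "header." + payload + ".signature"

    parts := strings.Split(token, ".")
    raw, err := base64.RawURLEncoding.DecodeString(parts[1])
    if err != nil {
        panic(err)
    }

    var claims map[string]interface{}
    if err := json.Unmarshal(raw, &claims); err != nil {
        panic(err)
    }

    fmt.Println(claims["sub"]) // anonymous
}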
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
@@ -1158,6 +1005,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, err
}
anonymous = getTokenSubject(token) == "anonymous"
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
@@ -1178,6 +1026,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
}
}
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized
}
@@ -1255,7 +1113,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
}
}
var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
......
@@ -5,39 +5,14 @@ import (
"fmt"
"io"
"os"
"strings"
"golang.org/x/exp/slices"
)
type Layers struct {
items []*Layer
}
func (ls *Layers) Add(layer *Layer) {
if layer.Size > 0 {
ls.items = append(ls.items, layer)
}
}
func (ls *Layers) Replace(layer *Layer) {
if layer.Size > 0 {
mediatype := layer.MediaType
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
return l.MediaType == mediatype
})
ls.items = append(layers, layer)
}
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
tempFileName string
status string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err
}
const delimiter = "-"
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
temp, err := os.CreateTemp(blobs, pattern)
temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil {
return nil, err
}
defer temp.Close()
defer os.Remove(temp.Name())
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
return nil, err
}
if err := temp.Close(); err != nil {
return nil, err
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
status := "using existing layer"
if _, err := os.Stat(blob); err != nil {
status = "creating new layer"
if err := os.Rename(temp.Name(), blob); err != nil {
return nil, err
}
}
return &Layer{
MediaType: mediatype,
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
Size: n,
tempFileName: temp.Name(),
MediaType: mediatype,
Digest: digest,
Size: n,
status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
func (l *Layer) Commit() (bool, error) {
// always remove temp
defer os.Remove(l.tempFileName)
func (l *Layer) Open() (io.ReadCloser, error) {
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return false, err
}
if _, err := os.Stat(blob); err != nil {
return true, os.Rename(l.tempFileName, blob)
return nil, err
}
return false, nil
return os.Open(blob)
}
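With layers now written straight into the content-addressed blob store, a test-style sketch of the round trip; it assumes OLLAMA_MODELS points the blob directory at a temp dir and that GetBlobsPath creates it on demand:

package server

import (
    "io"
    "strings"
    "testing"
)

func TestLayerRoundTrip(t *testing.T) {
    t.Setenv("OLLAMA_MODELS", t.TempDir())

    // NewLayer hashes the bytes; identical content maps to the same blob.
    layer, err := NewLayer(strings.NewReader("hello"), "application/vnd.ollama.image.template")
    if err != nil {
        t.Fatal(err)
    }

    r, err := layer.Open()
    if err != nil {
        t.Fatal(err)
    }
    defer r.Close()

    b, err := io.ReadAll(r)
    if err != nil {
        t.Fatal(err)
    }

    if string(b) != "hello" {
        t.Fatalf("unexpected layer content %q", b)
    }
}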
@@ -2,11 +2,56 @@ package server
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"github.com/ollama/ollama/types/model"
)
type Manifest struct {
ManifestV2
Digest string `json:"-"`
}
func (m *Manifest) Size() (size int64) {
for _, layer := range append(m.Layers, m.Config) {
size += layer.Size
}
return
}
func ParseNamedManifest(name model.Name) (*Manifest, error) {
if !name.IsFullyQualified() {
return nil, model.Unqualified(name)
}
manifests, err := GetManifestPath()
if err != nil {
return nil, err
}
var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
if err != nil {
return nil, err
}
sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
return nil, err
}
return &Manifest{
ManifestV2: manifest,
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil
}
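For the new Size helper, a sketch under the assumption that ManifestV2 carries Layers []*Layer and Config *Layer (consistent with WriteManifest below): the total is the layer sizes plus the config layer.

package server

import "fmt"

func manifestSizeExample() {
    // hypothetical manifest: two 10-byte layers plus a 5-byte config layer
    m := Manifest{
        ManifestV2: ManifestV2{
            Layers: []*Layer{{Size: 10}, {Size: 10}},
            Config: &Layer{Size: 5},
        },
    }

    fmt.Println(m.Size()) // 25
}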
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
......
package server

import (
    "archive/zip"
    "bytes"
    "context"
    "errors"
    "fmt"
    "io"
    "net/http"
    "os"
    "path/filepath"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/convert"
    "github.com/ollama/ollama/llm"
    "github.com/ollama/ollama/types/model"
)

type layerWithGGML struct {
    *Layer
    *llm.GGML
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
    modelpath := ParseModelPath(name.String())
    manifest, _, err := GetManifest(modelpath)
    switch {
    case errors.Is(err, os.ErrNotExist):
        if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
            return nil, err
        }

        modelpath = ParseModelPath(name.String())
        manifest, _, err = GetManifest(modelpath)
        if err != nil {
            return nil, err
        }
    case err != nil:
        return nil, err
    }

    for _, layer := range manifest.Layers {
        layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
        if err != nil {
            return nil, err
        }

        switch layer.MediaType {
        case "application/vnd.ollama.image.model",
            "application/vnd.ollama.image.projector",
            "application/vnd.ollama.image.adapter":
            blobpath, err := GetBlobsPath(layer.Digest)
            if err != nil {
                return nil, err
            }

            blob, err := os.Open(blobpath)
            if err != nil {
                return nil, err
            }
            defer blob.Close()

            ggml, _, err := llm.DecodeGGML(blob)
            if err != nil {
                return nil, err
            }

            layers = append(layers, &layerWithGGML{layer, ggml})
        default:
            layers = append(layers, &layerWithGGML{layer, nil})
        }
    }

    return layers, nil
}

func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
    stat, err := file.Stat()
    if err != nil {
        return nil, err
    }

    r, err := zip.NewReader(file, stat.Size())
    if err != nil {
        return nil, err
    }

    tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
    if err != nil {
        return nil, err
    }
    defer os.RemoveAll(tempdir)

    fn(api.ProgressResponse{Status: "unpacking model metadata"})
    for _, f := range r.File {
        // TODO(mxyng): this should not write out all files to disk
        outfile, err := os.Create(filepath.Join(tempdir, f.Name))
        if err != nil {
            return nil, err
        }
        defer outfile.Close()

        infile, err := f.Open()
        if err != nil {
            return nil, err
        }
        defer infile.Close()

        if _, err = io.Copy(outfile, infile); err != nil {
            return nil, err
        }

        if err := outfile.Close(); err != nil {
            return nil, err
        }

        if err := infile.Close(); err != nil {
            return nil, err
        }
    }

    mf, err := convert.GetModelFormat(tempdir)
    if err != nil {
        return nil, err
    }

    params, err := mf.GetParams(tempdir)
    if err != nil {
        return nil, err
    }

    mArch, err := mf.GetModelArch("", tempdir, params)
    if err != nil {
        return nil, err
    }

    fn(api.ProgressResponse{Status: "processing tensors"})
    if err := mArch.GetTensors(); err != nil {
        return nil, err
    }

    if err := mArch.LoadVocab(); err != nil {
        return nil, err
    }

    fn(api.ProgressResponse{Status: "converting model"})

    // TODO(mxyng): this should write directly into a layer
    // e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
    temp, err := os.CreateTemp(tempdir, "fp16")
    if err != nil {
        return nil, err
    }
    defer temp.Close()
    defer os.Remove(temp.Name())

    if err = mArch.WriteGGUF(temp); err != nil {
        return nil, err
    }

    if _, err := temp.Seek(0, io.SeekStart); err != nil {
        return nil, err
    }

    layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
    if err != nil {
return nil, fmt.Errorf("aaa: %w", err)
    }

    blobpath, err := GetBlobsPath(layer.Digest)
    if err != nil {
        return nil, err
    }

    bin, err := os.Open(blobpath)
    if err != nil {
        return nil, err
    }
    defer bin.Close()

    ggml, _, err := llm.DecodeGGML(bin)
    if err != nil {
        return nil, err
    }

    layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
    if err != nil {
        return nil, err
    }

    layers = append(layers, &layerWithGGML{layer, ggml})
    return layers, nil
}

func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
    sr := io.NewSectionReader(file, 0, 512)
    contentType, err := detectContentType(sr)
    if err != nil {
        return nil, err
    }

    switch contentType {
    case "gguf", "ggla":
        // noop
    case "application/zip":
        return parseFromZipFile(ctx, file, fn)
    default:
        return nil, fmt.Errorf("unsupported content type: %s", contentType)
    }

    stat, err := file.Stat()
    if err != nil {
        return nil, err
    }

    var offset int64
    for offset < stat.Size() {
        ggml, n, err := llm.DecodeGGML(file)
        if errors.Is(err, io.EOF) {
            break
        } else if err != nil {
            return nil, err
        }

        mediatype := "application/vnd.ollama.image.model"
        if ggml.Name() == "ggla" {
            mediatype = "application/vnd.ollama.image.adapter"
        } else if ggml.KV().Architecture() == "clip" {
            mediatype = "application/vnd.ollama.image.projector"
        }

        layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
        if err != nil {
            return nil, err
        }

        layers = append(layers, &layerWithGGML{layer, ggml})
        offset = n
    }

    return layers, nil
}

func detectContentType(r io.Reader) (string, error) {
    var b bytes.Buffer
    if _, err := io.Copy(&b, r); err != nil {
        return "", err
    }

    if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
        return contentType, nil
    }

    if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
        return contentType, nil
    }

    return "unknown", nil
}
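A sketch of the sniffing fallback, assuming llm.DetectGGMLType does not claim zip bytes so they fall through to net/http's sniffer, which routes parseFromFile into parseFromZipFile:

package server

import (
    "bytes"
    "fmt"
)

func detectExample() {
    // "PK\x03\x04" is the zip local-file-header magic recognized by http.DetectContentType.
    ct, err := detectContentType(bytes.NewReader([]byte("PK\x03\x04...")))
    if err != nil {
        panic(err)
    }

    fmt.Println(ct) // application/zip
}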
@@ -6,6 +6,7 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
)
@@ -25,9 +26,10 @@
)
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format")
)
func ParseModelPath(name string) ModelPath {
@@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) {
return "", err
}
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(dir, "blobs", digest)
dirPath := filepath.Dir(path)
......
package server
import "testing"
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test")
assert.Nil(t, err)
defer os.RemoveAll(dir)
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(dir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", dir)
got, err := GetBlobsPath(tc.digest)
assert.ErrorIs(t, err, tc.err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) {
tests := []struct {
......
package server
import (
"cmp"
"context"
"encoding/json"
"errors"
@@ -15,11 +16,8 @@ import (
"os"
"os/signal"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
@@ -31,14 +29,16 @@ import (
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
addr net.Addr
sched *Scheduler
}
func init() {
@@ -53,88 +53,8 @@ func init() {
gin.SetMode(mode)
}
var loaded struct {
mu sync.Mutex
llama *llm.LlamaServer
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
needLoad := loaded.llama == nil || // is there a model loaded?
loaded.model != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
loaded.llama.Ping(ctx) != nil
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
}
loaded.model = model.ModelPath
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
unload()
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -154,9 +74,7 @@ func isSupportedImageType(image []byte) bool {
return slices.Contains(allowedTypes, contentType)
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -224,8 +142,12 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
@@ -275,7 +197,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset()
if req.Context != nil {
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -297,9 +219,6 @@ func GenerateHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -331,7 +250,7 @@ func GenerateHandler(c *gin.Context) {
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
@@ -359,7 +278,7 @@ func GenerateHandler(c *gin.Context) {
Images: images,
Options: opts,
}
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -421,10 +340,7 @@ func getDefaultSessionDuration() time.Duration {
return defaultSessionDuration
}
func EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -469,8 +385,12 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
@@ -480,7 +400,7 @@ func EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -493,7 +413,7 @@ func EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func PullModelHandler(c *gin.Context) {
func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -542,7 +462,7 @@ func PullModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func PushModelHandler(c *gin.Context) {
func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -591,30 +511,19 @@ func PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func CreateModelHandler(c *gin.Context) {
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if err := ParseModelPath(model).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
@@ -623,19 +532,19 @@ func CreateModelHandler(c *gin.Context) {
return
}
var modelfile io.Reader = strings.NewReader(req.Modelfile)
var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
mf, err := os.Open(req.Path)
f, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer mf.Close()
defer f.Close()
modelfile = mf
r = f
}
commands, err := parser.Parse(modelfile)
modelfile, err := model.ParseFile(r)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -651,7 +560,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -664,7 +573,7 @@ func CreateModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func DeleteModelHandler(c *gin.Context) {
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -709,7 +618,7 @@ func DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
func ShowModelHandler(c *gin.Context) {
func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
@@ -799,109 +708,115 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
mf, err := ShowModelfile(model)
if err != nil {
return nil, err
}
resp.Modelfile = mf
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generate by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName)
fmt.Fprint(&sb, model.String())
resp.Modelfile = sb.String()
return resp, nil
}
func ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
manifestsPath, err := GetManifestPath()
func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
modelResponse := func(modelName string) (api.ModelResponse, error) {
model, err := GetModel(modelName)
if err != nil {
return api.ModelResponse{}, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
var models []api.ModelResponse
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
return api.ModelResponse{
Model: model.ShortName,
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
Details: modelDetails,
}, nil
}
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
path, tag := filepath.Split(path)
model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
modelPath := strings.Join([]string{model, tag}, ":")
canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
n := model.ParseNameFromFilepath(rel)
m, err := ParseNamedManifest(n)
if err != nil {
return err
}
resp, err := modelResponse(canonicalModelPath)
f, err := m.Config.Open()
if err != nil {
slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
// nolint: nilerr
return nil
return err
}
defer f.Close()
resp.ModifiedAt = info.ModTime()
models = append(models, resp)
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
return err
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
}
return nil
}
if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
// most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
})
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
func (s *Server) CopyModelHandler(c *gin.Context) {
var r api.CopyRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
src := model.ParseName(r.Source)
if !src.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
return
}
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
dst := model.ParseName(r.Destination)
if !dst.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
return
}
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
func HeadBlobHandler(c *gin.Context) {
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -916,7 +831,7 @@ func HeadBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
}
func CreateBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -946,20 +861,9 @@ func CreateBlobHandler(c *gin.Context) {
return
}
if _, err := layer.Commit(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Status(http.StatusCreated)
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
"0.0.0.0",
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
@@ -1031,6 +935,11 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
if allowedHost(host) {
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
return
}
@@ -1043,19 +952,8 @@ func (s *Server) GenerateRoutes() http.Handler {
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
config.AllowOrigins = strings.Split(allowedOrigins, ",")
}
for _, allowOrigin := range defaultAllowOrigins {
config.AllowOrigins = append(config.AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s:*", allowOrigin),
fmt.Sprintf("https://%s:*", allowOrigin),
)
}
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
config.AllowOrigins = envconfig.AllowOrigins
r := gin.Default()
r.Use(
@@ -1063,27 +961,27 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingsHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)
r.DELETE("/api/delete", s.DeleteModelHandler)
r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
......@@ -1094,10 +992,11 @@ func (s *Server) GenerateRoutes() http.Handler {
func Serve(ln net.Listener) error {
level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
if envconfig.Debug {
level = slog.LevelDebug
}
slog.Info("server config", "env", envconfig.AsMap())
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
@@ -1121,7 +1020,7 @@ func Serve(ln net.Listener) error {
return err
}
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
if !envconfig.NoPrune {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
@@ -1137,7 +1036,9 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
ctx, done := context.WithCancel(context.Background())
sched := InitScheduler(ctx)
s := &Server{addr: ln.Addr(), sched: sched}
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@@ -1150,7 +1051,9 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
unload()
srvr.Close()
done()
sched.unloadAllRunners()
gpu.Cleanup()
os.Exit(0)
}()
@@ -1158,12 +1061,12 @@ func Serve(ln net.Listener) error {
if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error())
}
}
s.sched.Run(ctx)
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
_ = gpu.GetGPUInfo()
return srvr.Serve(ln)
}
@@ -1219,9 +1122,9 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.llama.Tokenize(ctx, s)
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
@@ -1232,10 +1135,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
return prompt, nil
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
@@ -1292,8 +1192,12 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
@@ -1309,7 +1213,7 @@ func ChatHandler(c *gin.Context) {
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1352,8 +1256,6 @@ func ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
@@ -1376,7 +1278,7 @@ func ChatHandler(c *gin.Context) {
ch <- resp
}
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
@@ -1416,3 +1318,15 @@ func ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func handleErrorResponse(c *gin.Context, err error) {
if errors.Is(err, context.Canceled) {
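		// 499 is the nonstandard "client closed request" status popularized by
		// nginx; net/http defines no constant for it, so the literal is used here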
c.JSON(499, gin.H{"error": "request canceled"})
return
}
if errors.Is(err, ErrMaxQueue) {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) {
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
commands, err := parser.Parse(modelfile)
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := model.ParseFile(r)
assert.Nil(t, err)
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", commands, fn)
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
assert.Nil(t, err)
}
@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
Method: http.MethodPost,
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
f, err := os.CreateTemp(t.TempDir(), "ollama-model")
assert.Nil(t, err)
defer f.Close()
fname := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
Modelfile: fmt.Sprintf("FROM %s", f.Name()),
Modelfile: fmt.Sprintf("FROM %s", fname),
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
@@ -216,28 +214,25 @@ func Test_Routes(t *testing.T) {
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)
workDir, err := os.MkdirTemp("", "ollama-test")
assert.Nil(t, err)
defer os.RemoveAll(workDir)
os.Setenv("OLLAMA_MODELS", workDir)
t.Setenv("OLLAMA_MODELS", t.TempDir())
for _, tc := range testCases {
t.Logf("Running Test: [%s]", tc.Name)
u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err)
defer resp.Body.Close()
if tc.Expected != nil {
tc.Expected(t, resp)
}
t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
assert.Nil(t, err)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err)
defer resp.Body.Close()
if tc.Expected != nil {
tc.Expected(t, resp)
}
})
}
}
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
opts api.Options
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
}
type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList)
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
getGpuFn func() gpu.GpuInfoList
}
var ErrMaxQueue = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
func InitScheduler(ctx context.Context) *Scheduler {
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests),
unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
}
sched.loadFn = sched.load
return sched
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
// allocate a large enough kv cache for all parallel requests
opts.NumCtx = opts.NumCtx * envconfig.NumParallel
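	// e.g. with OLLAMA_NUM_PARALLEL=4 (illustrative value) a request asking for a
	// 2048-token context reserves an 8192-token kv cache, so each parallel slot
	// keeps its full window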
req := &LlmRequest{
ctx: c,
model: model,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
errCh: make(chan error, 1),
}
select {
case s.pendingReqCh <- req:
default:
req.errCh <- ErrMaxQueue
}
return req.successCh, req.errCh
}
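// Illustrative caller sketch (names assumed; this mirrors the handlers in
// routes.go above):
//
//	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
//	var runner *runnerRef
//	select {
//	case runner = <-rCh:
//		// use runner.llama; the ref count drops once c.Request.Context() is done
//	case err := <-eCh:
//		handleErrorResponse(c, err)
//	}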
// Returns immediately, spawns goroutines for the scheduler which will shut down when ctx is done
func (s *Scheduler) Run(ctx context.Context) {
slog.Debug("starting llm scheduler")
go func() {
s.processPending(ctx)
}()
go func() {
s.processCompleted(ctx)
}()
}
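// Lifecycle sketch (mirrors Serve in routes.go): both loops stop when the
// context passed to InitScheduler/Run is cancelled:
//
//	ctx, done := context.WithCancel(context.Background())
//	sched := InitScheduler(ctx)
//	sched.Run(ctx) // returns immediately
//	defer done()   // shuts down processPending and processCompleted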
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
if pending.ctx.Err() != nil {
slog.Debug("pending request cancelled or timed out, skipping scheduling")
continue
}
for {
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
runnerToExpire = runner
} else {
// Runner is usable, return it
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload()
} else {
// Either no models are loaded or below envconfig.MaxRunners
// Get a refreshed GPU list
gpus := s.getGpuFn()
// Load model for fitting
ggml, err := llm.LoadModel(pending.model.ModelPath)
if err != nil {
pending.errCh <- err
break
}
// If we're in CPU-only mode, just limit by envconfig.MaxRunners above
// TODO handle system memory exhaustion
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
slog.Debug("cpu mode with existing models, loading")
s.loadFn(pending, ggml, gpus)
break
}
// No models loaded. Load the model but prefer the best fit.
if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)
g := pickBestFitGPUs(pending, ggml, gpus)
if g != nil {
gpus = g
}
s.loadFn(pending, ggml, gpus)
break
}
// One or more models are already loaded, so we have to see if the new one fits
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
gpus = pickBestFitGPUs(pending, ggml, gpus)
if gpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, ggml, gpus)
break
}
runnerToExpire = s.findRunnerToUnload()
}
if runnerToExpire == nil {
// Shouldn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending requests can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
})
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model)
runner.refMu.Lock()
if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
s.loadedMu.Lock()
slog.Debug("got lock to unload", "model", runner.model)
runner.unload()
delete(s.loaded, runner.model)
s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model)
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "model", runner.model)
s.unloadedCh <- struct{}{}
}
}
}
// Complete the pending request and send the runner back to the requester
// Wires up a finished event after the request context is completed
// Updates session duration, and resets expiration timer
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
runner.refMu.Lock()
defer runner.refMu.Unlock()
runner.refCount++
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
finished <- pending
}()
}
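// Release protocol: the reference taken above is only dropped in
// processCompleted after the request context ends, so a caller releases the
// runner by cancelling its context (or letting the HTTP request finish).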
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) {
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
runner := &runnerRef{}
runner.model = req.model.ModelPath
runner.adapters = req.model.AdapterPaths
runner.projectors = req.model.ProjectorPaths
runner.llama = llama
runner.Options = &req.opts
runner.sessionDuration = req.sessionDuration
runner.gpus = gpus
runner.estimatedVRAM = llama.EstimatedVRAM()
runner.loading = true
runner.refCount = 1
runner.refMu.Lock()
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
}
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
for _, r := range s.loaded {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
gpuIDs = append(gpuIDs, gpu.ID)
}
for _, gpu := range allGpus {
if slices.Contains(gpuIDs, gpu.ID) {
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
allGpus[i].FreeMemory = 0
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
}
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
}
}
}
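// Worked example (mirrors TestUpdateFreeSpace below): two runners spread across
// the same two GPUs with estimated VRAM of 100 and 200 bytes predict 50+100=150
// bytes of use per GPU; a GPU reporting total=1000 free=900 is then clamped to
// total-predicted = 850, the smaller of the two free estimates.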
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
sessionDuration time.Duration
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
// The refMu must already be held when calling unload
func (runner *runnerRef) unload() {
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
if runner.llama != nil {
runner.llama.Close()
}
runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil
runner.gpus = nil
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
if runner.Options == nil {
return true
}
// Don't reload runner if num_gpu=-1 was provided
optsExisting := runner.Options.Runner
optsNew := req.opts.Runner
if optsNew.NumGPU < 0 {
optsExisting.NumGPU = -1
optsNew.NumGPU = -1
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}
return false
}
type ByDuration []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}
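// Usage (see findRunnerToUnload below): sort.Sort(ByDuration(runnerList)) puts
// the shortest session duration first; a negative duration ("never unload")
// wraps to a huge uint64 and therefore sorts last.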
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model cannot fully fit within the available GPU(s), nil is returned
func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList {
var estimatedVRAM uint64
for _, gl := range gpus.ByLibrary() {
var ok bool
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
// First attempt to fit the model into a single GPU
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
return []gpu.GpuInfo{g}
}
}
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUs
if ok, estimatedVRAM = llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
return gl
}
}
return nil
}
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload() *runnerRef {
s.loadedMu.Lock()
runnerList := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runnerList = append(runnerList, r)
}
s.loadedMu.Unlock()
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {
runner.refMu.Lock()
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
return runnerList[0]
}
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
runner.llama.Close()
}
}
}
package server
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"os"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/server/envconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
os.Setenv("OLLAMA_DEBUG", "1")
lifecycle.InitLogging()
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
s := InitScheduler(ctx)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
var ggml *llm.GGML // value not used in tests
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return nil, fmt.Errorf("something failed to load model blah")
}
gpus := gpu.GpuInfoList{}
s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, ggml, gpus)
select {
case err := <-req.errCh:
require.NoError(t, err)
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
}
req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure")
s.load(req, ggml, gpus)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"]
s.loadedMu.Unlock()
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
time.Sleep(1 * time.Millisecond)
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
req *LlmRequest
ggml *llm.GGML
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
err = gguf.Encode(f, llm.KV{
"general.architecture": "llama",
"general.name": "name",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
scenario.ggml, err = llm.LoadModel(model.ModelPath)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
return scenario
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.ggml = scenario1a.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
envconfig.MaxRunners = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
envconfig.MaxRunners = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
s.pendingReqCh <- scenario3c.req
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
// Try to load a model that won't fit
s.newServerFn = scenario3d.newServer
slog.Info("scenario3d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
scenario3b.ctxDone()
select {
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3d.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
scenario1a.ctxDone()
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processed to return an error
time.Sleep(5 * time.Millisecond)
require.Len(t, successCh1c, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
case <-ctx.Done():
t.Errorf("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
time.Sleep(5 * time.Millisecond)
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
require.Equal(t, time.Duration(2), r1.sessionDuration)
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case <-ctx.Done():
t.Errorf("timeout")
}
done()
fin := <-finished
require.Equal(t, req, fin)
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
gpus := gpu.GpuInfoList{
{
Library: "a",
ID: "1",
},
{
Library: "a",
ID: "2",
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{estimatedVRAM: 100}
llm2 := &mockLlm{estimatedVRAM: 200}
r1 := &runnerRef{llama: llm1, gpus: gpus}
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
resp := s.findRunnerToUnload()
require.Equal(t, r2, resp)
r2.refCount = 1
resp = s.findRunnerToUnload()
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
llm := &mockLlm{}
do := api.DefaultOptions()
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &do,
llama: llm,
}
req := &LlmRequest{
model: &Model{
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.DefaultOptions(),
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
req.model.AdapterPaths = runner.adapters
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.model.ProjectorPaths = runner.projectors
runner.loading = true
req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumBatch = runner.Options.NumBatch
llm.pingResp = fmt.Errorf("foo")
resp = runner.needsReload(ctx, req)
require.True(t, resp)
llm.pingResp = nil
resp = runner.needsReload(ctx, req)
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumGPU = -1
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
llm1 := &mockLlm{}
llm2 := &mockLlm{}
s := InitScheduler(ctx)
s.unloadAllRunners()
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
require.True(t, llm2.closeCalled)
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}}
r1.unload()
require.True(t, llm1.closeCalled)
r2.unload()
require.Nil(t, r2.adapters)
}
type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
detokenizeRespErr error
closeResp error
closeCalled bool
estimatedVRAM uint64
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr
}
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.detokenizeResp, s.detokenizeRespErr
}
func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const UnknownOllamaKeyErrMsg = "unknown ollama key"
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}
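// Illustrative use (assumed caller, not part of this package):
//
//	var uke *UnknownOllamaKey
//	if errors.As(err, &uke) {
//		fmt.Println("unknown key:", uke.Key)
//	}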
package model
import (
"log/slog"
"strings"
"unicode"
)
// Digest represents a digest of a model Manifest. It is a comparable value
// type and is immutable.
//
// The zero Digest is not a valid digest.
type Digest struct {
s string
}
// Type returns the digest type of the digest.
//
// Example:
//
// ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
typ, _, _ := strings.Cut(d.s, "-")
return typ
}
// String returns the digest in the form of "<digest-type>-<digest>", or the
// empty string if the digest is invalid.
func (d Digest) String() string { return d.s }
// IsValid returns true if the digest is valid (not zero).
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }
// LogValue implements slog.Value.
func (d Digest) LogValue() slog.Value {
return slog.StringValue(d.String())
}
var (
_ slog.LogValuer = Digest{}
)
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
// Digest.
func ParseDigest(s string) Digest {
typ, digest, ok := strings.Cut(s, "-")
if ok && isValidDigestType(typ) && isValidHex(digest) {
return Digest{s: s}
}
return Digest{}
}
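// Example:
//
//	ParseDigest("sha256-abc123").IsValid() // true
//	ParseDigest("sha256-xyz").IsValid()    // false: "xyz" is not lowercase hex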
func isValidDigestType(s string) bool {
if len(s) == 0 {
return false
}
for _, r := range s {
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
return false
}
}
return true
}
func isValidHex(s string) bool {
if len(s) == 0 {
return false
}
for i := range s {
c := s[i]
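		// relies on && binding tighter than ||, i.e.
		// c < '0' || (c > '9' && c < 'a') || c > 'f'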
if c < '0' || c > '9' && c < 'a' || c > 'f' {
return false
}
}
return true
}
package model
import "testing"
var testDigests = map[string]Digest{
"": {},
"sha256-1234": {s: "sha256-1234"},
"sha256-5678": {s: "sha256-5678"},
"blake2-9abc": {s: "blake2-9abc"},
"-1234": {},
"sha256-": {},
"sha256-1234-5678": {},
"sha256-P": {}, // invalid hex
"sha256-1234P": {},
"---": {},
}
func TestDigestParse(t *testing.T) {
// Test cases.
for s, want := range testDigests {
got := ParseDigest(s)
t.Logf("ParseDigest(%q) = %#v", s, got)
if got != want {
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
}
}
}
func TestDigestString(t *testing.T) {
// Test cases.
for s, d := range testDigests {
want := s
if !d.IsValid() {
want = ""
}
got := d.String()
if got != want {
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
}
got = ParseDigest(s).String()
if got != want {
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
}
}
}