Unverified Commit 5c191276 authored by Michael Yang, committed by GitHub

Merge pull request #5473 from ollama/mxyng/environ

fix: environ lookup
parents 71399aa6 85d9d73a
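
In short: this merge replaces envconfig's package-level variables, which were populated once by LoadConfig(), with accessor functions that re-read the environment on every call. Call sites change from envconfig.Debug to envconfig.Debug(), OLLAMA_HOST parsing now yields a *url.URL directly, and tests can rely on t.Setenv alone instead of re-running envconfig.LoadConfig() after each change.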
...
@@ -20,7 +20,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"net"
 	"net/http"
 	"net/url"
 	"runtime"
...
@@ -63,13 +62,8 @@ func checkError(resp *http.Response, body []byte) error {
 // If the variable is not specified, a default ollama host and port will be
 // used.
 func ClientFromEnvironment() (*Client, error) {
-	ollamaHost := envconfig.Host
-
 	return &Client{
-		base: &url.URL{
-			Scheme: ollamaHost.Scheme,
-			Host:   net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
-		},
+		base: envconfig.Host(),
 		http: http.DefaultClient,
 	}, nil
 }
...
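To see the client-side effect, here is a minimal sketch of constructing a client against the new envconfig.Host() behavior. It assumes the github.com/ollama/ollama/api package; the host value is illustrative, and Heartbeat is assumed from that package's client API.

	package main

	import (
		"context"
		"fmt"
		"log"
		"os"

		"github.com/ollama/ollama/api"
	)

	func main() {
		// envconfig.Host() re-reads OLLAMA_HOST on every call, so setting it
		// before ClientFromEnvironment is all the configuration needed.
		os.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")

		client, err := api.ClientFromEnvironment()
		if err != nil {
			log.Fatal(err)
		}

		// Heartbeat pings the server at the configured base URL.
		if err := client.Heartbeat(context.Background()); err != nil {
			log.Fatal(err)
		}

		fmt.Println("ollama server is reachable")
	}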
...
@@ -2,8 +2,6 @@ package api

 import (
 	"testing"
-
-	"github.com/ollama/ollama/envconfig"
 )

 func TestClientFromEnvironment(t *testing.T) {
...
@@ -33,7 +31,6 @@ func TestClientFromEnvironment(t *testing.T) {
 	for k, v := range testCases {
 		t.Run(k, func(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)
-			envconfig.LoadConfig()
 			client, err := ClientFromEnvironment()
 			if err != v.err {
...
...
@@ -14,7 +14,7 @@ import (
 func InitLogging() {
 	level := slog.LevelInfo
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		level = slog.LevelDebug
 	}
...
...
@@ -1076,7 +1076,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 		return err
 	}

-	ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
+	ln, err := net.Listen("tcp", envconfig.Host().Host)
 	if err != nil {
 		return err
 	}
...
...
@@ -160,7 +160,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		return err
 	}

-	if envconfig.NoHistory {
+	if envconfig.NoHistory() {
 		scanner.HistoryDisable()
 	}
...
 package envconfig

 import (
-	"errors"
 	"fmt"
 	"log/slog"
 	"math"
 	"net"
+	"net/url"
 	"os"
 	"path/filepath"
 	"runtime"
...
@@ -14,347 +14,271 @@ import (
 	"time"
 )

-type OllamaHost struct {
-	Scheme string
-	Host   string
-	Port   string
-}
-
-func (o OllamaHost) String() string {
-	return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
-}
-
-var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
+// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable.
+// Default is scheme "http" and host "127.0.0.1:11434"
+func Host() *url.URL {
+	defaultPort := "11434"
+
+	s := strings.TrimSpace(Var("OLLAMA_HOST"))
+	scheme, hostport, ok := strings.Cut(s, "://")
+	switch {
+	case !ok:
+		scheme, hostport = "http", s
+	case scheme == "http":
+		defaultPort = "80"
+	case scheme == "https":
+		defaultPort = "443"
+	}
+
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		host, port = "127.0.0.1", defaultPort
+		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
+			host = ip.String()
+		} else if hostport != "" {
+			host = hostport
+		}
+	}
+
+	if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 {
+		slog.Warn("invalid port, using default", "port", port, "default", defaultPort)
+		return &url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, defaultPort),
+		}
+	}
+
+	return &url.URL{
+		Scheme: scheme,
+		Host:   net.JoinHostPort(host, port),
+	}
+}

-var (
-	// Set via OLLAMA_ORIGINS in the environment
-	AllowOrigins []string
-	// Set via OLLAMA_DEBUG in the environment
-	Debug bool
-	// Experimental flash attention
-	FlashAttention bool
-	// Set via OLLAMA_HOST in the environment
-	Host *OllamaHost
-	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive time.Duration
-	// Set via OLLAMA_LLM_LIBRARY in the environment
-	LLMLibrary string
-	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
-	MaxRunners int
-	// Set via OLLAMA_MAX_QUEUE in the environment
-	MaxQueuedRequests int
-	// Set via OLLAMA_MODELS in the environment
-	ModelsDir string
-	// Set via OLLAMA_NOHISTORY in the environment
-	NoHistory bool
-	// Set via OLLAMA_NOPRUNE in the environment
-	NoPrune bool
-	// Set via OLLAMA_NUM_PARALLEL in the environment
-	NumParallel int
-	// Set via OLLAMA_RUNNERS_DIR in the environment
-	RunnersDir string
-	// Set via OLLAMA_SCHED_SPREAD in the environment
-	SchedSpread bool
-	// Set via OLLAMA_TMPDIR in the environment
-	TmpDir string
-	// Set via OLLAMA_INTEL_GPU in the environment
-	IntelGpu bool
-	// Set via CUDA_VISIBLE_DEVICES in the environment
-	CudaVisibleDevices string
-	// Set via HIP_VISIBLE_DEVICES in the environment
-	HipVisibleDevices string
-	// Set via ROCR_VISIBLE_DEVICES in the environment
-	RocrVisibleDevices string
-	// Set via GPU_DEVICE_ORDINAL in the environment
-	GpuDeviceOrdinal string
-	// Set via HSA_OVERRIDE_GFX_VERSION in the environment
-	HsaOverrideGfxVersion string
-)
+// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
+func Origins() (origins []string) {
+	if s := Var("OLLAMA_ORIGINS"); s != "" {
+		origins = strings.Split(s, ",")
+	}
+
+	for _, origin := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
+		origins = append(origins,
+			fmt.Sprintf("http://%s", origin),
+			fmt.Sprintf("https://%s", origin),
+			fmt.Sprintf("http://%s", net.JoinHostPort(origin, "*")),
+			fmt.Sprintf("https://%s", net.JoinHostPort(origin, "*")),
+		)
+	}
+
+	origins = append(origins,
+		"app://*",
+		"file://*",
+		"tauri://*",
+	)
+
+	return origins
+}
+
+// Models returns the path to the models directory. Models directory can be configured via the OLLAMA_MODELS environment variable.
+// Default is $HOME/.ollama/models
+func Models() string {
+	if s := Var("OLLAMA_MODELS"); s != "" {
+		return s
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		panic(err)
+	}
+
+	return filepath.Join(home, ".ollama", "models")
+}
+
+// KeepAlive returns the duration that models stay loaded in memory. KeepAlive can be configured via the OLLAMA_KEEP_ALIVE environment variable.
+// Negative values are treated as infinite. Zero is treated as no keep alive.
+// Default is 5 minutes.
+func KeepAlive() (keepAlive time.Duration) {
+	keepAlive = 5 * time.Minute
+	if s := Var("OLLAMA_KEEP_ALIVE"); s != "" {
+		if d, err := time.ParseDuration(s); err == nil {
+			keepAlive = d
+		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
+			keepAlive = time.Duration(n) * time.Second
+		}
+	}
+
+	if keepAlive < 0 {
+		return time.Duration(math.MaxInt64)
+	}
+
+	return keepAlive
+}
+
+func Bool(k string) func() bool {
+	return func() bool {
+		if s := Var(k); s != "" {
+			b, err := strconv.ParseBool(s)
+			if err != nil {
+				return true
+			}
+
+			return b
+		}
+
+		return false
+	}
+}
+
+var (
+	// Debug enabled additional debug information.
+	Debug = Bool("OLLAMA_DEBUG")
+	// FlashAttention enables the experimental flash attention feature.
+	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
+	// NoHistory disables readline history.
+	NoHistory = Bool("OLLAMA_NOHISTORY")
+	// NoPrune disables pruning of model blobs on startup.
+	NoPrune = Bool("OLLAMA_NOPRUNE")
+	// SchedSpread allows scheduling models across all GPUs.
+	SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
+	// IntelGPU enables experimental Intel GPU detection.
+	IntelGPU = Bool("OLLAMA_INTEL_GPU")
+)

-type EnvVar struct {
-	Name        string
-	Value       any
-	Description string
-}
-
-func AsMap() map[string]EnvVar {
-	ret := map[string]EnvVar{
-		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
-		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
-		"OLLAMA_HOST":              {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
-		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
-		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
-		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
-		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
-		"OLLAMA_MODELS":            {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
-		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
-		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
-		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
-		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
-		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
-		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
-		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
-	}
-	if runtime.GOOS != "darwin" {
-		ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"}
-		ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"}
-		ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"}
-		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"}
-		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"}
-		ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGpu, "Enable experimental Intel GPU detection"}
-	}
-	return ret
-}
-
-func Values() map[string]string {
-	vals := make(map[string]string)
-	for k, v := range AsMap() {
-		vals[k] = fmt.Sprintf("%v", v.Value)
-	}
-	return vals
-}
-
-var defaultAllowOrigins = []string{
-	"localhost",
-	"127.0.0.1",
-	"0.0.0.0",
-}
-
-// Clean quotes and spaces from the value
-func clean(key string) string {
-	return strings.Trim(os.Getenv(key), "\"' ")
-}
-
-func init() {
-	// default values
-	NumParallel = 0 // Autoselect
-	MaxRunners = 0  // Autoselect
-	MaxQueuedRequests = 512
-	KeepAlive = 5 * time.Minute
-
-	LoadConfig()
-}
-
-func LoadConfig() {
-	if debug := clean("OLLAMA_DEBUG"); debug != "" {
-		d, err := strconv.ParseBool(debug)
-		if err == nil {
-			Debug = d
-		} else {
-			Debug = true
-		}
-	}
-
-	if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
-		d, err := strconv.ParseBool(fa)
-		if err == nil {
-			FlashAttention = d
-		}
-	}
-
-	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
-	if runtime.GOOS == "windows" && RunnersDir == "" {
-		// On Windows we do not carry the payloads inside the main executable
-		appExe, err := os.Executable()
-		if err != nil {
-			slog.Error("failed to lookup executable path", "error", err)
-		}
-
-		cwd, err := os.Getwd()
-		if err != nil {
-			slog.Error("failed to lookup working directory", "error", err)
-		}
-
-		var paths []string
-		for _, root := range []string{filepath.Dir(appExe), cwd} {
-			paths = append(paths,
-				root,
-				filepath.Join(root, "windows-"+runtime.GOARCH),
-				filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
-			)
-		}
-
-		// Try a few variations to improve developer experience when building from source in the local tree
-		for _, p := range paths {
-			candidate := filepath.Join(p, "ollama_runners")
-			_, err := os.Stat(candidate)
-			if err == nil {
-				RunnersDir = candidate
-				break
-			}
-		}
-		if RunnersDir == "" {
-			slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
-		}
-	}
-
-	TmpDir = clean("OLLAMA_TMPDIR")
-
-	LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
-
-	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
-		val, err := strconv.Atoi(onp)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
-		} else {
-			NumParallel = val
-		}
-	}
-
-	if nohistory := clean("OLLAMA_NOHISTORY"); nohistory != "" {
-		NoHistory = true
-	}
-
-	if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" {
-		s, err := strconv.ParseBool(spread)
-		if err == nil {
-			SchedSpread = s
-		} else {
-			SchedSpread = true
-		}
-	}
-
-	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
-		NoPrune = true
-	}
-
-	if origins := clean("OLLAMA_ORIGINS"); origins != "" {
-		AllowOrigins = strings.Split(origins, ",")
-	}
-	for _, allowOrigin := range defaultAllowOrigins {
-		AllowOrigins = append(AllowOrigins,
-			fmt.Sprintf("http://%s", allowOrigin),
-			fmt.Sprintf("https://%s", allowOrigin),
-			fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
-			fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
-		)
-	}
-
-	AllowOrigins = append(AllowOrigins,
-		"app://*",
-		"file://*",
-		"tauri://*",
-	)
-
-	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
-	if maxRunners != "" {
-		m, err := strconv.Atoi(maxRunners)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
-		} else {
-			MaxRunners = m
-		}
-	}
-
-	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
-		p, err := strconv.Atoi(onp)
-		if err != nil || p <= 0 {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
-		} else {
-			MaxQueuedRequests = p
-		}
-	}
-
-	ka := clean("OLLAMA_KEEP_ALIVE")
-	if ka != "" {
-		loadKeepAlive(ka)
-	}
-
-	var err error
-	ModelsDir, err = getModelsDir()
-	if err != nil {
-		slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err)
-	}
-
-	Host, err = getOllamaHost()
-	if err != nil {
-		slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
-	}
-
-	if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil {
-		IntelGpu = set
-	}
-
-	CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES")
-	HipVisibleDevices = clean("HIP_VISIBLE_DEVICES")
-	RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES")
-	GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL")
-	HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION")
-}
-
-func getModelsDir() (string, error) {
-	if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
-		return models, nil
-	}
-	home, err := os.UserHomeDir()
-	if err != nil {
-		return "", err
-	}
-	return filepath.Join(home, ".ollama", "models"), nil
-}
-
-func getOllamaHost() (*OllamaHost, error) {
-	defaultPort := "11434"
-
-	hostVar := os.Getenv("OLLAMA_HOST")
-	hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
-
-	scheme, hostport, ok := strings.Cut(hostVar, "://")
-	switch {
-	case !ok:
-		scheme, hostport = "http", hostVar
-	case scheme == "http":
-		defaultPort = "80"
-	case scheme == "https":
-		defaultPort = "443"
-	}
-
-	// trim trailing slashes
-	hostport = strings.TrimRight(hostport, "/")
-
-	host, port, err := net.SplitHostPort(hostport)
-	if err != nil {
-		host, port = "127.0.0.1", defaultPort
-		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
-			host = ip.String()
-		} else if hostport != "" {
-			host = hostport
-		}
-	}
-
-	if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
-		return &OllamaHost{
-			Scheme: scheme,
-			Host:   host,
-			Port:   defaultPort,
-		}, ErrInvalidHostPort
-	}
-
-	return &OllamaHost{
-		Scheme: scheme,
-		Host:   host,
-		Port:   port,
-	}, nil
-}
-
-func loadKeepAlive(ka string) {
-	v, err := strconv.Atoi(ka)
-	if err != nil {
-		d, err := time.ParseDuration(ka)
-		if err == nil {
-			if d < 0 {
-				KeepAlive = time.Duration(math.MaxInt64)
-			} else {
-				KeepAlive = d
-			}
-		}
-	} else {
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			KeepAlive = time.Duration(math.MaxInt64)
-		} else {
-			KeepAlive = d
-		}
-	}
-}
+func String(s string) func() string {
+	return func() string {
+		return Var(s)
+	}
+}
+
+var (
+	LLMLibrary = String("OLLAMA_LLM_LIBRARY")
+	TmpDir     = String("OLLAMA_TMPDIR")
+
+	CudaVisibleDevices    = String("CUDA_VISIBLE_DEVICES")
+	HipVisibleDevices     = String("HIP_VISIBLE_DEVICES")
+	RocrVisibleDevices    = String("ROCR_VISIBLE_DEVICES")
+	GpuDeviceOrdinal      = String("GPU_DEVICE_ORDINAL")
+	HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
+)
+
+func RunnersDir() (p string) {
+	if p := Var("OLLAMA_RUNNERS_DIR"); p != "" {
+		return p
+	}
+
+	if runtime.GOOS != "windows" {
+		return
+	}
+
+	defer func() {
+		if p == "" {
+			slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
+		}
+	}()
+
+	// On Windows we do not carry the payloads inside the main executable
+	exe, err := os.Executable()
+	if err != nil {
+		return
+	}
+
+	cwd, err := os.Getwd()
+	if err != nil {
+		return
+	}
+
+	var paths []string
+	for _, root := range []string{filepath.Dir(exe), cwd} {
+		paths = append(paths,
+			root,
+			filepath.Join(root, "windows-"+runtime.GOARCH),
+			filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
+		)
+	}
+
+	// Try a few variations to improve developer experience when building from source in the local tree
+	for _, path := range paths {
+		candidate := filepath.Join(path, "ollama_runners")
+		if _, err := os.Stat(candidate); err == nil {
+			p = candidate
+			break
+		}
+	}
+
+	return p
+}
+
+func Uint(key string, defaultValue uint) func() uint {
+	return func() uint {
+		if s := Var(key); s != "" {
+			if n, err := strconv.ParseUint(s, 10, 64); err != nil {
+				slog.Warn("invalid environment variable, using default", "key", key, "value", s, "default", defaultValue)
+			} else {
+				return uint(n)
+			}
+		}
+
+		return defaultValue
+	}
+}
+
+var (
+	// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
+	NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
+	// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
+	MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
+	// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
+	MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
+	// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
+	MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
+)
+
+type EnvVar struct {
+	Name        string
+	Value       any
+	Description string
+}
+
+func AsMap() map[string]EnvVar {
+	ret := map[string]EnvVar{
+		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
+		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
+		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
+		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
+		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
+		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
+		"OLLAMA_MODELS":            {"OLLAMA_MODELS", Models(), "The path to the models directory"},
+		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
+		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
+		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
+		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
+		"OLLAMA_RUNNERS_DIR":       {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
+		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
+		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
+	}
+	if runtime.GOOS != "darwin" {
+		ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
+		ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"}
+		ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"}
+		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"}
+		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
+		ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
+	}
+	return ret
+}
+
+func Values() map[string]string {
+	vals := make(map[string]string)
+	for k, v := range AsMap() {
+		vals[k] = fmt.Sprintf("%v", v.Value)
+	}
+	return vals
+}
+
+// Var returns an environment variable stripped of leading and trailing quotes or spaces
+func Var(key string) string {
+	return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
+}
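The design choice that drives the whole diff is that Bool, String, and Uint return closures instead of values, so configuration is evaluated lazily on each call rather than once at init. A self-contained sketch of the same pattern follows; boolVar is an illustrative stand-in, not part of the package:

	package main

	import (
		"fmt"
		"os"
		"strconv"
		"strings"
	)

	// boolVar mirrors the envconfig.Bool pattern above: it returns a closure
	// that re-reads the environment on every call, so changes made after
	// startup (e.g. by t.Setenv in tests) take effect without a reload step.
	func boolVar(key string) func() bool {
		return func() bool {
			s := strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
			if s == "" {
				return false
			}
			b, err := strconv.ParseBool(s)
			if err != nil {
				// any non-empty, non-parsable value counts as enabled
				return true
			}
			return b
		}
	}

	func main() {
		debug := boolVar("OLLAMA_DEBUG")
		fmt.Println(debug()) // false: variable unset

		os.Setenv("OLLAMA_DEBUG", "1")
		fmt.Println(debug()) // true: picked up without any reload
	}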
 package envconfig

 import (
-	"fmt"
 	"math"
-	"net"
 	"testing"
 	"time"

-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
+	"github.com/google/go-cmp/cmp"
 )

-func TestConfig(t *testing.T) {
-	Debug = false // Reset whatever was loaded in init()
-	t.Setenv("OLLAMA_DEBUG", "")
-	LoadConfig()
-	require.False(t, Debug)
-	t.Setenv("OLLAMA_DEBUG", "false")
-	LoadConfig()
-	require.False(t, Debug)
-	t.Setenv("OLLAMA_DEBUG", "1")
-	LoadConfig()
-	require.True(t, Debug)
-	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
-	LoadConfig()
-	require.True(t, FlashAttention)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "")
-	LoadConfig()
-	require.Equal(t, 5*time.Minute, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
-	LoadConfig()
-	require.Equal(t, 3*time.Second, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
-	LoadConfig()
-	require.Equal(t, 1*time.Hour, KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
-	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
-	LoadConfig()
-	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
-}
+func TestHost(t *testing.T) {
+	cases := map[string]struct {
+		value  string
+		expect string
+	}{
+		"empty":               {"", "127.0.0.1:11434"},
+		"only address":        {"1.2.3.4", "1.2.3.4:11434"},
+		"only port":           {":1234", ":1234"},
+		"address and port":    {"1.2.3.4:1234", "1.2.3.4:1234"},
+		"hostname":            {"example.com", "example.com:11434"},
+		"hostname and port":   {"example.com:1234", "example.com:1234"},
+		"zero port":           {":0", ":0"},
+		"too large port":      {":66000", ":11434"},
+		"too small port":      {":-1", ":11434"},
+		"ipv6 localhost":      {"[::1]", "[::1]:11434"},
+		"ipv6 world open":     {"[::]", "[::]:11434"},
+		"ipv6 no brackets":    {"::1", "[::1]:11434"},
+		"ipv6 + port":         {"[::1]:1337", "[::1]:1337"},
+		"extra space":         {" 1.2.3.4 ", "1.2.3.4:11434"},
+		"extra quotes":        {"\"1.2.3.4\"", "1.2.3.4:11434"},
+		"extra space+quotes":  {" \" 1.2.3.4 \" ", "1.2.3.4:11434"},
+		"extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"},
+		"http":                {"http://1.2.3.4", "1.2.3.4:80"},
+		"http port":           {"http://1.2.3.4:4321", "1.2.3.4:4321"},
+		"https":               {"https://1.2.3.4", "1.2.3.4:443"},
+		"https port":          {"https://1.2.3.4:4321", "1.2.3.4:4321"},
+	}
+
+	for name, tt := range cases {
+		t.Run(name, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", tt.value)
+			if host := Host(); host.Host != tt.expect {
+				t.Errorf("%s: expected %s, got %s", name, tt.expect, host.Host)
+			}
+		})
+	}
+}

-func TestClientFromEnvironment(t *testing.T) {
-	type testCase struct {
-		value  string
-		expect string
-		err    error
-	}
-
-	hostTestCases := map[string]*testCase{
-		"empty":               {value: "", expect: "127.0.0.1:11434"},
-		"only address":        {value: "1.2.3.4", expect: "1.2.3.4:11434"},
-		"only port":           {value: ":1234", expect: ":1234"},
-		"address and port":    {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
-		"hostname":            {value: "example.com", expect: "example.com:11434"},
-		"hostname and port":   {value: "example.com:1234", expect: "example.com:1234"},
-		"zero port":           {value: ":0", expect: ":0"},
-		"too large port":      {value: ":66000", err: ErrInvalidHostPort},
-		"too small port":      {value: ":-1", err: ErrInvalidHostPort},
-		"ipv6 localhost":      {value: "[::1]", expect: "[::1]:11434"},
-		"ipv6 world open":     {value: "[::]", expect: "[::]:11434"},
-		"ipv6 no brackets":    {value: "::1", expect: "[::1]:11434"},
-		"ipv6 + port":         {value: "[::1]:1337", expect: "[::1]:1337"},
-		"extra space":         {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
-		"extra quotes":        {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
-		"extra space+quotes":  {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
-		"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
-	}
-
-	for k, v := range hostTestCases {
-		t.Run(k, func(t *testing.T) {
-			t.Setenv("OLLAMA_HOST", v.value)
-			LoadConfig()
-
-			oh, err := getOllamaHost()
-			if err != v.err {
-				t.Fatalf("expected %s, got %s", v.err, err)
-			}
-
-			if err == nil {
-				host := net.JoinHostPort(oh.Host, oh.Port)
-				assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
-			}
-		})
-	}
-}
+func TestOrigins(t *testing.T) {
+	cases := []struct {
+		value  string
+		expect []string
+	}{
+		{"", []string{
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://10.0.0.1", []string{
+			"http://10.0.0.1",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://172.16.0.1,https://192.168.0.1", []string{
+			"http://172.16.0.1",
+			"https://192.168.0.1",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+		{"http://totally.safe,http://definitely.legit", []string{
+			"http://totally.safe",
+			"http://definitely.legit",
+			"http://localhost",
+			"https://localhost",
+			"http://localhost:*",
+			"https://localhost:*",
+			"http://127.0.0.1",
+			"https://127.0.0.1",
+			"http://127.0.0.1:*",
+			"https://127.0.0.1:*",
+			"http://0.0.0.0",
+			"https://0.0.0.0",
+			"http://0.0.0.0:*",
+			"https://0.0.0.0:*",
+			"app://*",
+			"file://*",
+			"tauri://*",
+		}},
+	}
+	for _, tt := range cases {
+		t.Run(tt.value, func(t *testing.T) {
+			t.Setenv("OLLAMA_ORIGINS", tt.value)
+
+			if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
+				t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
+			}
+		})
+	}
+}
+
+func TestBool(t *testing.T) {
+	cases := map[string]bool{
+		"":      false,
+		"true":  true,
+		"false": false,
+		"1":     true,
+		"0":     false,
+		// invalid values
+		"random":    true,
+		"something": true,
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_BOOL", k)
+			if b := Bool("OLLAMA_BOOL")(); b != v {
+				t.Errorf("%s: expected %t, got %t", k, v, b)
+			}
+		})
+	}
+}
+
+func TestUint(t *testing.T) {
+	cases := map[string]uint{
+		"0":    0,
+		"1":    1,
+		"1337": 1337,
+		// default values
+		"":       11434,
+		"-1":     11434,
+		"0o10":   11434,
+		"0x10":   11434,
+		"string": 11434,
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_UINT", k)
+			if i := Uint("OLLAMA_UINT", 11434)(); i != v {
+				t.Errorf("%s: expected %d, got %d", k, v, i)
+			}
+		})
+	}
+}
+
+func TestKeepAlive(t *testing.T) {
+	cases := map[string]time.Duration{
+		"":       5 * time.Minute,
+		"1s":     time.Second,
+		"1m":     time.Minute,
+		"1h":     time.Hour,
+		"5m0s":   5 * time.Minute,
+		"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
+		"0":      time.Duration(0),
+		"60":     60 * time.Second,
+		"120":    2 * time.Minute,
+		"3600":   time.Hour,
+		"-0":     time.Duration(0),
+		"-1":     time.Duration(math.MaxInt64),
+		"-1m":    time.Duration(math.MaxInt64),
+		// invalid values
+		" ":   5 * time.Minute,
+		"???": 5 * time.Minute,
+		"1d":  5 * time.Minute,
+		"1y":  5 * time.Minute,
+		"1w":  5 * time.Minute,
+	}
+
+	for tt, expect := range cases {
+		t.Run(tt, func(t *testing.T) {
+			t.Setenv("OLLAMA_KEEP_ALIVE", tt)
+			if actual := KeepAlive(); actual != expect {
+				t.Errorf("%s: expected %s, got %s", tt, expect, actual)
+			}
+		})
+	}
+}
+
+func TestVar(t *testing.T) {
+	cases := map[string]string{
+		"value":       "value",
+		" value ":     "value",
+		" 'value' ":   "value",
+		` "value" `:   "value",
+		" ' value ' ": " value ",
+		` " value " `: " value ",
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_VAR", k)
+			if s := Var("OLLAMA_VAR"); s != v {
+				t.Errorf("%s: expected %q, got %q", k, v, s)
+			}
+		})
+	}
+}
...
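Because every accessor re-reads the environment, the rewritten tests need only t.Setenv; no reload step remains. A hedged sketch of that pattern against KeepAlive (the test name is illustrative and not part of the commit; the expectation matches the "zero means no keep alive" behavior documented above):

	package envconfig

	import (
		"testing"
		"time"
	)

	// TestKeepAliveZero is an illustrative sketch: setting the variable via
	// t.Setenv is enough, because KeepAlive() reads the environment on call.
	func TestKeepAliveZero(t *testing.T) {
		t.Setenv("OLLAMA_KEEP_ALIVE", "0")
		if d := KeepAlive(); d != time.Duration(0) {
			t.Errorf("expected 0, got %s", d)
		}
	}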
...
@@ -61,9 +61,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
 	var visibleDevices []string
-	hipVD := envconfig.HipVisibleDevices   // zero based index only
-	rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID
-	gpuDO := envconfig.GpuDeviceOrdinal    // zero based index
+	hipVD := envconfig.HipVisibleDevices()   // zero based index only
+	rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID
+	gpuDO := envconfig.GpuDeviceOrdinal()    // zero based index
 	switch {
 	// TODO is this priorty order right?
 	case hipVD != "":
...
@@ -76,7 +76,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 		visibleDevices = strings.Split(gpuDO, ",")
 	}

-	gfxOverride := envconfig.HsaOverrideGfxVersion
+	gfxOverride := envconfig.HsaOverrideGfxVersion()
 	var supported []string
 	libDir := ""
...
...
@@ -53,7 +53,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	}
 	var supported []string
-	gfxOverride := envconfig.HsaOverrideGfxVersion
+	gfxOverride := envconfig.HsaOverrideGfxVersion()
 	if gfxOverride == "" {
 		supported, err = GetSupportedGFX(libDir)
 		if err != nil {
...
...
@@ -26,7 +26,7 @@ func PayloadsDir() (string, error) {
 	defer lock.Unlock()
 	var err error
 	if payloadsDir == "" {
-		runnersDir := envconfig.RunnersDir
+		runnersDir := envconfig.RunnersDir()

 		if runnersDir != "" {
 			payloadsDir = runnersDir
...
@@ -35,7 +35,7 @@ func PayloadsDir() (string, error) {
 		// The remainder only applies on non-windows where we still carry payloads in the main executable
 		cleanupTmpDirs()
-		tmpDir := envconfig.TmpDir
+		tmpDir := envconfig.TmpDir()
 		if tmpDir == "" {
 			tmpDir, err = os.MkdirTemp("", "ollama")
 			if err != nil {
...
@@ -105,7 +105,7 @@ func cleanupTmpDirs() {
 func Cleanup() {
 	lock.Lock()
 	defer lock.Unlock()
-	runnersDir := envconfig.RunnersDir
+	runnersDir := envconfig.RunnersDir()
 	if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
 		// We want to fully clean up the tmpdir parent of the payloads dir
 		tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
...
...
@@ -230,8 +230,8 @@ func GetGPUInfo() GpuInfoList {
 		// On windows we bundle the nvidia library one level above the runner dir
 		depPath := ""
-		if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-			depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "cuda")
+		if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
+			depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "cuda")
 		}

 		// Load ALL libraries
...
@@ -302,12 +302,12 @@ func GetGPUInfo() GpuInfoList {
 		}

 		// Intel
-		if envconfig.IntelGpu {
+		if envconfig.IntelGPU() {
 			oHandles = initOneAPIHandles()
 			// On windows we bundle the oneapi library one level above the runner dir
 			depPath = ""
-			if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
-				depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir), "oneapi")
+			if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
+				depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
 			}

 			for d := range oHandles.oneapi.num_drivers {
...
@@ -611,7 +611,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
 }

 func getVerboseState() C.uint16_t {
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		return C.uint16_t(1)
 	}
 	return C.uint16_t(0)
...
...
@@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) {
 	defer os.RemoveAll(modelDir)
 	slog.Info("unicode", "OLLAMA_MODELS", modelDir)

-	oldModelsDir := os.Getenv("OLLAMA_MODELS")
-	if oldModelsDir == "" {
-		defer os.Unsetenv("OLLAMA_MODELS")
-	} else {
-		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
-	}
-	err = os.Setenv("OLLAMA_MODELS", modelDir)
-	require.NoError(t, err)
+	t.Setenv("OLLAMA_MODELS", modelDir)

 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
...
...
@@ -5,14 +5,16 @@ package integration

 import (
 	"context"
 	"log/slog"
 	"os"
 	"strconv"
 	"sync"
 	"testing"
 	"time"

-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 )

 func TestMultiModelConcurrency(t *testing.T) {
...
@@ -106,13 +108,16 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {

 // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
 func TestMultiModelStress(t *testing.T) {
-	vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
-	if vram == "" {
+	s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
+	if s == "" {
 		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
 	}
-	max, err := strconv.ParseUint(vram, 10, 64)
-	require.NoError(t, err)
-	const MB = uint64(1024 * 1024)
+
+	maxVram, err := strconv.ParseUint(s, 10, 64)
+	if err != nil {
+		t.Fatal(err)
+	}

 	type model struct {
 		name string
 		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
...
@@ -121,83 +126,82 @@ func TestMultiModelStress(t *testing.T) {
 	smallModels := []model{
 		{
 			name: "orca-mini",
-			size: 2992 * MB,
+			size: 2992 * format.MebiByte,
 		},
 		{
 			name: "phi",
-			size: 2616 * MB,
+			size: 2616 * format.MebiByte,
 		},
 		{
 			name: "gemma:2b",
-			size: 2364 * MB,
+			size: 2364 * format.MebiByte,
 		},
 		{
 			name: "stable-code:3b",
-			size: 2608 * MB,
+			size: 2608 * format.MebiByte,
 		},
 		{
 			name: "starcoder2:3b",
-			size: 2166 * MB,
+			size: 2166 * format.MebiByte,
 		},
 	}
 	mediumModels := []model{
 		{
 			name: "llama2",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "orca-mini:7b",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "dolphin-mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "gemma:7b",
-			size: 5000 * MB,
+			size: 5000 * format.MebiByte,
+		},
+		{
+			name: "codellama:7b",
+			size: 5118 * format.MebiByte,
 		},
-		// TODO - uncomment this once #3565 is merged and this is rebased on it
-		// {
-		// 	name: "codellama:7b",
-		// 	size: 5118 * MB,
-		// },
 	}

 	// These seem to be too slow to be useful...
 	// largeModels := []model{
 	// 	{
 	// 		name: "llama2:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "codellama:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "orca-mini:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "gemma:7b",
-	// 		size: 5000 * MB,
+	// 		size: 5000 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "starcoder2:15b",
-	// 		size: 9100 * MB,
+	// 		size: 9100 * format.MebiByte,
 	// 	},
 	// }

 	var chosenModels []model
 	switch {
-	case max < 10000*MB:
+	case maxVram < 10000*format.MebiByte:
 		slog.Info("selecting small models")
 		chosenModels = smallModels
-	// case max < 30000*MB:
+	// case maxVram < 30000*format.MebiByte:
 	default:
 		slog.Info("selecting medium models")
 		chosenModels = mediumModels
...
@@ -226,15 +230,15 @@ func TestMultiModelStress(t *testing.T) {
 	}

 	var wg sync.WaitGroup
-	consumed := uint64(256 * MB) // Assume some baseline usage
+	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
 	for i := 0; i < len(req); i++ {
 		// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
-		if i > 1 && consumed > max {
-			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		if i > 1 && consumed > maxVram {
+			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
 			break
 		}
 		consumed += chosenModels[i].size
-		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
 		wg.Add(1)
 		go func(i int) {
...
...@@ -5,7 +5,6 @@ package integration ...@@ -5,7 +5,6 @@ package integration
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"os" "os"
"strconv" "strconv"
...@@ -14,8 +13,10 @@ import ( ...@@ -14,8 +13,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
) )
func TestMaxQueue(t *testing.T) { func TestMaxQueue(t *testing.T) {
...@@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) { ...@@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) {
// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU // Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits // Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
threadCount := 32 threadCount := 32
mq := os.Getenv("OLLAMA_MAX_QUEUE") if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
if mq != "" { threadCount = maxQueue
var err error
threadCount, err = strconv.Atoi(mq)
require.NoError(t, err)
} else { } else {
os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount)) t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
} }
req := api.GenerateRequest{ req := api.GenerateRequest{
......
...
@@ -8,14 +8,14 @@ import (
 	"testing"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/gpu"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )

 func TestEstimateGPULayers(t *testing.T) {
-	envconfig.Debug = true
+	t.Setenv("OLLAMA_DEBUG", "1")
 	modelName := "dummy"
 	f, err := os.CreateTemp(t.TempDir(), modelName)
 	require.NoError(t, err)
...
...
@@ -163,7 +163,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	} else {
 		servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
 	}
-	demandLib := envconfig.LLMLibrary
+	demandLib := envconfig.LLMLibrary()
 	if demandLib != "" {
 		serverPath := availableServers[demandLib]
 		if serverPath == "" {
...
@@ -195,7 +195,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
 	}

-	if envconfig.Debug {
+	if envconfig.Debug() {
 		params = append(params, "--verbose")
 	}
...
@@ -221,7 +221,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--memory-f32")
 	}

-	flashAttnEnabled := envconfig.FlashAttention
+	flashAttnEnabled := envconfig.FlashAttention()

 	for _, g := range gpus {
 		// only cuda (compute capability 7+) and metal support flash attention
...
@@ -382,7 +382,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	}

 	slog.Info("starting llama server", "cmd", s.cmd.String())
-	if envconfig.Debug {
+	if envconfig.Debug() {
 		filteredEnv := []string{}
 		for _, ev := range s.cmd.Env {
 			if strings.HasPrefix(ev, "CUDA_") ||
...
...
@@ -646,7 +646,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 		return err
 	}

-	if !envconfig.NoPrune && old != nil {
+	if !envconfig.NoPrune() && old != nil {
 		if err := old.RemoveLayers(); err != nil {
 			return err
 		}
...
@@ -885,7 +885,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	// build deleteMap to prune unused layers
 	deleteMap := make(map[string]struct{})

-	if !envconfig.NoPrune {
+	if !envconfig.NoPrune() {
 		manifest, _, err = GetManifest(mp)
 		if err != nil && !errors.Is(err, os.ErrNotExist) {
 			return err
...
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"slices" "slices"
"testing" "testing"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
...@@ -108,7 +107,6 @@ func TestManifests(t *testing.T) { ...@@ -108,7 +107,6 @@ func TestManifests(t *testing.T) {
t.Run(n, func(t *testing.T) { t.Run(n, func(t *testing.T) {
d := t.TempDir() d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d) t.Setenv("OLLAMA_MODELS", d)
envconfig.LoadConfig()
for _, p := range wants.ps { for _, p := range wants.ps {
createManifest(t, d, p) createManifest(t, d, p)
......
...
@@ -105,9 +105,7 @@ func (mp ModelPath) GetShortTagname() string {
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
-	dir := envconfig.ModelsDir
-
-	return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
+	return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
 }

 func (mp ModelPath) BaseURL() *url.URL {
...
@@ -118,9 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 }

 func GetManifestPath() (string, error) {
-	dir := envconfig.ModelsDir
-
-	path := filepath.Join(dir, "manifests")
+	path := filepath.Join(envconfig.Models(), "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
 		return "", err
 	}
...
@@ -129,8 +125,6 @@ func GetManifestPath() (string, error) {
 }

 func GetBlobsPath(digest string) (string, error) {
-	dir := envconfig.ModelsDir
-
 	// only accept actual sha256 digests
 	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
 	re := regexp.MustCompile(pattern)
...
@@ -140,7 +134,7 @@ func GetBlobsPath(digest string) (string, error) {
 	}

 	digest = strings.ReplaceAll(digest, ":", "-")
-	path := filepath.Join(dir, "blobs", digest)
+	path := filepath.Join(envconfig.Models(), "blobs", digest)
 	dirPath := filepath.Dir(path)
 	if digest == "" {
 		dirPath = path
...
...@@ -7,8 +7,6 @@ import ( ...@@ -7,8 +7,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ollama/ollama/envconfig"
) )
func TestGetBlobsPath(t *testing.T) { func TestGetBlobsPath(t *testing.T) {
...@@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) { ...@@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", dir) t.Setenv("OLLAMA_MODELS", dir)
envconfig.LoadConfig()
got, err := GetBlobsPath(tc.digest) got, err := GetBlobsPath(tc.digest)
......