You need to sign in or sign up before continuing.
Commit 93492f1e authored by cmiller01's avatar cmiller01
Browse files

correct precedence of serve params (args over env over default)

parent fb593b7b
...@@ -513,28 +513,39 @@ func generateBatch(cmd *cobra.Command, model string) error { ...@@ -513,28 +513,39 @@ func generateBatch(cmd *cobra.Command, model string) error {
return nil return nil
} }
func RunServer(cmd *cobra.Command, _ []string) error { // getRunServerParams takes a command and the environment variables and returns the correct params
host, err := cmd.Flags().GetString("host") // given the order of precedence: command line args (highest), environment variables, defaults (lowest)
if err != nil { func getRunServerParams(cmd *cobra.Command) (host, port string, extraOrigins []string, err error) {
return errors.New("host unset") host = os.Getenv("OLLAMA_HOST")
} hostFlag := cmd.Flags().Lookup("host")
if os.Getenv("OLLAMA_HOST") != "" { if hostFlag == nil {
host = os.Getenv("OLLAMA_HOST") return "", "", nil, errors.New("host unset")
} }
port, err := cmd.Flags().GetString("port") if hostFlag.Changed || host == "" {
host = hostFlag.Value.String()
}
port = os.Getenv("OLLAMA_PORT")
portFlag := cmd.Flags().Lookup("port")
if portFlag == nil {
return "", "", nil, errors.New("port unset")
}
if portFlag.Changed || port == "" {
port = portFlag.Value.String()
}
extraOrigins, err = cmd.Flags().GetStringSlice("allowed-origins")
if err != nil { if err != nil {
return errors.New("port unset") return "", "", nil, err
}
if os.Getenv("OLLAMA_PORT") != "" {
port = os.Getenv("OLLAMA_PORT")
} }
return host, port, extraOrigins, nil
}
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port)) func RunServer(cmd *cobra.Command, _ []string) error {
host, port, extraOrigins, err := getRunServerParams(cmd)
if err != nil { if err != nil {
return err return err
} }
extraOrigins, err := cmd.Flags().GetStringSlice("allowed-origins")
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%s", host, port))
if err != nil { if err != nil {
return err return err
} }
......
package cmd
import (
"os"
"testing"
)
func TestGetRunServerParams(t *testing.T) {
t.Run("default values", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
host, port, extraOrigins, err := getRunServerParams(serveCmd)
// assertions
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "127.0.0.1" {
t.Errorf("unexpected host, got %s", host)
}
if port != "11434" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 0 {
t.Errorf("unexpected origins, got %s", extraOrigins)
}
})
t.Run("environment variables take precedence over default", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
// setup environment variables
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
if err != nil {
t.Errorf("could not set env var")
}
err = os.Setenv("OLLAMA_PORT", "9999")
if err != nil {
t.Errorf("could not set env var")
}
defer func() {
os.Unsetenv("OLLAMA_HOST")
os.Unsetenv("OLLAMA_PORT")
}()
host, port, extraOrigins, err := getRunServerParams(serveCmd)
// assertions
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "0.0.0.0" {
t.Errorf("unexpected host, got %s", host)
}
if port != "9999" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 0 {
t.Errorf("unexpected origins, got %s", extraOrigins)
}
})
t.Run("command line args take precedence over env vars", func(t *testing.T) {
cmd := NewCLI()
serveCmd, _, err := cmd.Find([]string{"serve"})
if err != nil {
t.Errorf("expected serve command, got %s", err)
}
// setup environment variables
err = os.Setenv("OLLAMA_HOST", "0.0.0.0")
if err != nil {
t.Errorf("could not set env var")
}
err = os.Setenv("OLLAMA_PORT", "9999")
if err != nil {
t.Errorf("could not set env var")
}
defer func() {
os.Unsetenv("OLLAMA_HOST")
os.Unsetenv("OLLAMA_PORT")
}()
// now set command flags
serveCmd.Flags().Set("host", "localhost")
serveCmd.Flags().Set("port", "8888")
serveCmd.Flags().Set("allowed-origins", "http://foo.example.com,http://192.168.1.1")
host, port, extraOrigins, err := getRunServerParams(serveCmd)
if err != nil {
t.Errorf("unexpected error, got %s", err)
}
if host != "localhost" {
t.Errorf("unexpected host, got %s", host)
}
if port != "8888" {
t.Errorf("unexpected port, got %s", port)
}
if len(extraOrigins) != 2 {
t.Errorf("expected two origins, got length %d", len(extraOrigins))
}
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment