Commit 7004fab5 authored by xuxzh1

update

parent 42dd5af5
@@ -15,7 +15,7 @@ import (
 func TestEstimateGPULayers(t *testing.T) {
 	t.Setenv("OLLAMA_DEBUG", "1")
+	t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
 	modelName := "dummy"
 	f, err := os.CreateTemp(t.TempDir(), modelName)
 	require.NoError(t, err)
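The added t.Setenv pins OLLAMA_KV_CACHE_TYPE to the empty string so the layer estimates in this test assume the default f16 KV cache. To see why the cache type matters to memory estimation, here is a minimal, self-contained sketch. It assumes the usual layers * 2 (K and V) * context * embedding layout and the GGML block sizes, ignores grouped-query attention (which shrinks the KV dimension on many models), and none of its numbers come from this diff:

```go
package main

import "fmt"

// kvCacheBytes gives a rough KV cache cost for one context window.
// Bytes-per-element figures follow the GGML block formats: f16 is 2 B,
// q8_0 is 34 B per 32 elements, q4_0 is 18 B per 32 elements.
func kvCacheBytes(cacheType string, layers, embedding, contextLen uint64) uint64 {
	bytesPerElem := map[string]float64{
		"f16":  2.0,
		"q8_0": 34.0 / 32,
		"q4_0": 18.0 / 32,
	}
	bpe, ok := bytesPerElem[cacheType]
	if !ok {
		bpe = 2.0 // unknown types fall back to the f16 default
	}
	return uint64(float64(layers*2*contextLen*embedding) * bpe)
}

func main() {
	// An 8B-class model: 32 layers, 4096-dim KV, 8192-token context.
	for _, t := range []string{"f16", "q8_0", "q4_0"} {
		fmt.Printf("%-5s %5d MiB\n", t, kvCacheBytes(t, 32, 4096, 8192)/(1<<20))
	}
}
```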
@@ -214,14 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		params = append(params, "--threads", strconv.Itoa(defaultThreads))
 	}
 
-	flashAttnEnabled := envconfig.FlashAttention()
+	fa := envconfig.FlashAttention()
+	if fa && !gpus.FlashAttentionSupported() {
+		slog.Warn("flash attention enabled but not supported by gpu")
+		fa = false
+	}
 
-	for _, g := range gpus {
-		// only cuda (compute capability 7+) and metal support flash attention
-		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
-			flashAttnEnabled = false
-		}
+	if fa && !ggml.SupportsFlashAttention() {
+		slog.Warn("flash attention enabled but not supported by model")
+		fa = false
+	}
+
+	kvct := strings.ToLower(envconfig.KvCacheType())
+
+	if fa {
+		slog.Info("enabling flash attention")
+		params = append(params, "--flash-attn")
+
+		// Flash Attention also supports kv cache quantization
+		// Enable if the requested and kv cache type is supported by the model
+		if kvct != "" && ggml.SupportsKVCacheType(kvct) {
+			params = append(params, "--kv-cache-type", kvct)
+		} else {
+			slog.Warn("kv cache type not supported by model", "type", kvct)
+		}
+	} else if kvct != "" && kvct != "f16" {
+		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
+	}
 
+	for _, g := range gpus {
 		// mmap has issues with partial offloading on metal
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
@@ -231,9 +253,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		}
 	}
 
-	if flashAttnEnabled {
-		params = append(params, "--flash-attn")
-	}
-
 	// Windows CUDA should not use mmap for best performance
 	// Linux with a model larger than free space, mmap leads to thrashing
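The rewritten gate is a two-stage capability check: flash attention must be requested (OLLAMA_FLASH_ATTENTION), supported by the detected GPUs, and supported by the model before --flash-attn is passed, and a quantized KV cache (OLLAMA_KV_CACHE_TYPE, e.g. q8_0) is only honored when flash attention survives both checks and the model supports that cache type; otherwise the server warns and falls back to f16. A standalone paraphrase of that ladder, with the support checks stubbed out (the helper below is illustrative, not part of the diff):

```go
package main

import "fmt"

// resolveAttnParams paraphrases the gating above: flash attention is
// requested, vetoed by a GPU or model capability check, and only a
// surviving flash-attention setup may add a quantized KV cache type.
// gpuOK and modelOK stand in for gpus.FlashAttentionSupported() and
// ggml.SupportsFlashAttention(); kvOK stands in for ggml.SupportsKVCacheType.
func resolveAttnParams(requested, gpuOK, modelOK bool, kvct string, kvOK func(string) bool) []string {
	fa := requested && gpuOK && modelOK
	if !fa {
		return nil // a quantized cache request is dropped with a warning here
	}
	params := []string{"--flash-attn"}
	if kvct != "" && kvOK(kvct) {
		params = append(params, "--kv-cache-type", kvct)
	}
	return params
}

func main() {
	yes := func(string) bool { return true }
	fmt.Println(resolveAttnParams(true, true, true, "q8_0", yes))  // [--flash-attn --kv-cache-type q8_0]
	fmt.Println(resolveAttnParams(true, false, true, "q8_0", yes)) // []; the GPU veto also disables the quantized cache
}
```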
@@ -617,27 +636,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 const jsonGrammar = `
 root   ::= object
 value  ::= object | array | string | number | ("true" | "false" | "null") ws
 object ::=
   "{" ws (
     string ":" ws value
     ("," ws string ":" ws value)*
   )? "}" ws
 array  ::=
   "[" ws (
     value
     ("," ws value)*
   )? "]" ws
 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws
 number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 # Optional space: by convention, applied in this grammar after literal chars when allowed
 ws ::= ([ \t\n] ws)?
 `
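jsonGrammar is a GBNF grammar that constrains token sampling to syntactically valid JSON: the root must be an object, strings admit the standard escapes, and ws permits optional whitespace after literals. As the next hunks show, the server attaches it verbatim to the llama.cpp completion request under the "grammar" key. A small sketch of that wiring, with the grammar abbreviated and the request field names following llama.cpp server conventions rather than anything defined in this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Abbreviated stand-in for the jsonGrammar constant above.
	const jsonGrammar = `root ::= object` // full grammar omitted here

	// Mirrors request["grammar"] = jsonGrammar in the hunk below; the
	// sampler then masks any token that would violate the grammar.
	request := map[string]any{
		"prompt":  "Reply in JSON. Who wrote The Go Programming Language?",
		"grammar": jsonGrammar,
	}

	b, _ := json.Marshal(request)
	fmt.Println(string(b))
}
```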
@@ -667,7 +681,7 @@ type completion struct {
 type CompletionRequest struct {
 	Prompt  string
-	Format  string
+	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
 }
@@ -732,10 +746,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	if req.Format == "json" {
-		request["grammar"] = jsonGrammar
-		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-			slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+	// TODO (parthsareen): Move conversion to grammar with sampling logic
+	// API should do error handling for invalid formats
+	if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
+		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
+			request["grammar"] = jsonGrammar
+			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
+				slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+			}
+		} else if schema, err := func() (llama.JsonSchema, error) {
+			var schema llama.JsonSchema
+			err := json.Unmarshal(req.Format, &schema)
+			return schema, err
+		}(); err == nil {
+			request["grammar"] = schema.AsGrammar()
+		} else {
+			slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
 		}
 	}
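With Format now a json.RawMessage, callers can send either the literal string "json" (the legacy JSON mode, which selects jsonGrammar) or a whole JSON schema, which is unmarshaled into llama.JsonSchema and compiled to a grammar via AsGrammar; anything else logs a warning and is ignored. A hedged sketch of the two call shapes (CompletionRequest and its Format field come from this diff; the prompts and schema are illustrative):

```go
// Two ways to populate the new Format field; this sketch would sit next
// to CompletionRequest in the llm package.
func exampleFormats() []CompletionRequest {
	return []CompletionRequest{
		{
			// Legacy JSON mode: selects the jsonGrammar constant above.
			Prompt: "Reply in JSON: name the two largest moons of Mars.",
			Format: json.RawMessage(`"json"`),
		},
		{
			// A schema: unmarshaled into llama.JsonSchema and compiled
			// to a grammar with AsGrammar, both shown in this diff.
			Prompt: "List the two largest moons of Mars.",
			Format: json.RawMessage(`{
				"type": "object",
				"properties": {"moons": {"type": "array", "items": {"type": "string"}}},
				"required": ["moons"]
			}`),
		},
	}
}
```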
@@ -62,7 +62,12 @@ type Usage struct {
 }
 
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string      `json:"type"`
+	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
+}
+
+type JsonSchema struct {
+	Schema map[string]any `json:"schema"`
 }
 
 type EmbedRequest struct {
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["top_p"] = 1.0
 	}
 
-	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
-		format = "json"
+	var format json.RawMessage
+	if r.ResponseFormat != nil {
+		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
+		// Support the old "json_object" type for OpenAI compatibility
+		case "json_object":
+			format = json.RawMessage(`"json"`)
+		case "json_schema":
+			if r.ResponseFormat.JsonSchema != nil {
+				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+				if err != nil {
+					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
+				}
+				format = schema
+			}
+		}
 	}
 
 	return &api.ChatRequest{
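On the OpenAI-compatible endpoint, the two standard response_format variants now map onto the raw format field: json_object becomes the literal "json", and json_schema forwards the embedded schema. A sketch of both variants using the struct types added above (the schema contents are illustrative):

```go
// The two response_format variants the middleware now accepts; this
// sketch would sit in the openai package alongside ResponseFormat.
func exampleResponseFormats() []ResponseFormat {
	return []ResponseFormat{
		// Legacy type, mapped to the format literal "json".
		{Type: "json_object"},
		// Structured outputs: the schema is forwarded as raw JSON.
		{
			Type: "json_schema",
			JsonSchema: &JsonSchema{
				Schema: map[string]any{
					"type":       "object",
					"required":   []string{"city"},
					"properties": map[string]any{"city": map[string]any{"type": "string"}},
				},
			},
		},
	}
}
```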
@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
 )
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
 				"presence_penalty": 5.0,
 				"top_p":            6.0,
 			},
-			Format: "json",
+			Format: json.RawMessage(`"json"`),
 			Stream: &True,
 		},
 	},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
 			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 				t.Fatal(err)
 			}
+			return
 		}
 
-		if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-			t.Fatal("requests did not match")
+		if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+			t.Fatalf("requests did not match: %+v", diff)
 		}
 
-		if !reflect.DeepEqual(tc.err, errResp) {
-			t.Fatal("errors did not match")
+		if diff := cmp.Diff(tc.err, errResp); diff != "" {
+			t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 		}
 	})
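Replacing reflect.DeepEqual with go-cmp makes test failures actionable: cmp.Diff returns a human-readable delta rather than a bare boolean, and the added return short-circuits the request comparison once an error response has been checked. A minimal illustration of the diff output (the struct here is a placeholder, not a type from these tests):

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

type req struct {
	Model  string
	Format string
}

func main() {
	want := req{Model: "llama3", Format: `"json"`}
	got := req{Model: "llama3", Format: ""}
	// Unlike reflect.DeepEqual, the diff pinpoints the differing field.
	if diff := cmp.Diff(want, got); diff != "" {
		fmt.Printf("requests did not match (-want +got):\n%s", diff)
	}
}
```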
@@ -637,4 +638,4 @@ func TestRetrieveMiddleware(t *testing.T) {
 			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
 		}
 	}
-}
\ No newline at end of file
+}
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	if req.Format != "" && req.Format != "json" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
-		return
-	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
 		return
 	}
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	var b bytes.Buffer
 	if req.Context != nil {
+		slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
 		s, err := r.Detokenize(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1582,4 +1580,4 @@ func handleScheduleError(c *gin.Context, name string, err error) {
 	default:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-}
\ No newline at end of file
+}