Commit 7004fab5 authored by xuxzh1

update

parent 42dd5af5
@@ -15,7 +15,7 @@ import (
 func TestEstimateGPULayers(t *testing.T) {
 	t.Setenv("OLLAMA_DEBUG", "1")
+	t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
 	modelName := "dummy"
 	f, err := os.CreateTemp(t.TempDir(), modelName)
 	require.NoError(t, err)
...
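The KV cache is typically the largest runtime allocation after the weights, which is why the estimator test now pins `OLLAMA_KV_CACHE_TYPE`. A back-of-the-envelope sketch of per-token cache size by cache type; the model shape and the f16/q8_0/q4_0 block sizes are assumptions borrowed from llama.cpp's quant formats, not Ollama's actual estimator:

```go
package main

import "fmt"

// kvBytesPerToken approximates the KV cache footprint per token:
// K and V each hold kvHeads*headDim elements per layer.
func kvBytesPerToken(layers, kvHeads, headDim int, bytesPerElt float64) float64 {
	return 2 * float64(layers*kvHeads*headDim) * bytesPerElt
}

func main() {
	// Hypothetical Llama-3-8B-like shape: 32 layers, 8 KV heads, head dim 128.
	const layers, kvHeads, headDim = 32, 8, 128
	for _, t := range []struct {
		name string
		bpe  float64 // bytes per element
	}{
		{"f16", 2.0},        // 2 bytes per element
		{"q8_0", 34.0 / 32}, // assumed 34-byte blocks of 32 elements
		{"q4_0", 18.0 / 32}, // assumed 18-byte blocks of 32 elements
	} {
		fmt.Printf("%-5s ~%.0f KiB per token\n", t.name,
			kvBytesPerToken(layers, kvHeads, headDim, t.bpe)/1024)
	}
}
```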
@@ -214,14 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		params = append(params, "--threads", strconv.Itoa(defaultThreads))
 	}

-	flashAttnEnabled := envconfig.FlashAttention()
-
-	for _, g := range gpus {
-		// only cuda (compute capability 7+) and metal support flash attention
-		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
-			flashAttnEnabled = false
-		}
-
+	fa := envconfig.FlashAttention()
+	if fa && !gpus.FlashAttentionSupported() {
+		slog.Warn("flash attention enabled but not supported by gpu")
+		fa = false
+	}
+
+	if fa && !ggml.SupportsFlashAttention() {
+		slog.Warn("flash attention enabled but not supported by model")
+		fa = false
+	}
+
+	kvct := strings.ToLower(envconfig.KvCacheType())
+
+	if fa {
+		slog.Info("enabling flash attention")
+		params = append(params, "--flash-attn")
+
+		// Flash Attention also supports kv cache quantization
+		// Enable it if the requested kv cache type is supported by the model
+		if kvct != "" && ggml.SupportsKVCacheType(kvct) {
+			params = append(params, "--kv-cache-type", kvct)
+		} else {
+			slog.Warn("kv cache type not supported by model", "type", kvct)
+		}
+	} else if kvct != "" && kvct != "f16" {
+		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
+	}
+
+	for _, g := range gpus {
 		// mmap has issues with partial offloading on metal
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
@@ -231,9 +253,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		}
 	}

-	if flashAttnEnabled {
-		params = append(params, "--flash-attn")
-	}

 	// Windows CUDA should not use mmap for best performance
 	// Linux with a model larger than free space, mmap leads to thrashing
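The removed inline compute-capability check is replaced by three gates: the env flag, GPU support, and model support. A quantized KV cache only applies when all three pass. A minimal sketch of that decision as a standalone function; the boolean parameters are illustrative stand-ins for `gpus.FlashAttentionSupported()`, `ggml.SupportsFlashAttention()`, and `ggml.SupportsKVCacheType()`, not the real code:

```go
package main

import "fmt"

// serverFlags condenses the gating above into a pure function.
func serverFlags(requested, gpuOK, modelOK, kvOK bool, kvct string) []string {
	fa := requested && gpuOK && modelOK // every check must pass
	var flags []string
	if fa {
		flags = append(flags, "--flash-attn")
		if kvct != "" && kvOK {
			// quantized KV cache (e.g. q8_0) rides on flash attention
			flags = append(flags, "--kv-cache-type", kvct)
		}
	}
	return flags
}

func main() {
	fmt.Println(serverFlags(true, true, true, true, "q8_0"))
	// [--flash-attn --kv-cache-type q8_0]
	fmt.Println(serverFlags(true, false, true, true, "q8_0"))
	// [] : one GPU lacks support, so the KV cache stays f16 as well
}
```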
@@ -617,27 +636,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 const jsonGrammar = `
 root ::= object
 value ::= object | array | string | number | ("true" | "false" | "null") ws

 object ::=
   "{" ws (
     string ":" ws value
     ("," ws string ":" ws value)*
   )? "}" ws

 array ::=
   "[" ws (
     value
     ("," ws value)*
   )? "]" ws

 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws

 number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

 # Optional space: by convention, applied in this grammar after literal chars when allowed
 ws ::= ([ \t\n] ws)?
 `
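For context on this unchanged grammar: `root ::= object` means every grammar-constrained completion is a JSON object, so a bare top-level string, array, or number cannot be produced, and the right-recursive `ws ::= ([ \t\n] ws)?` admits any run of spaces, tabs, or newlines. A quick host-side check that mirrors the root constraint (illustrative only, not part of the commit):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// looksLikeGrammarOutput mirrors the grammar's root rule: valid JSON
// whose top-level value is an object.
func looksLikeGrammarOutput(s string) bool {
	t := strings.TrimSpace(s)
	return json.Valid([]byte(t)) && strings.HasPrefix(t, "{")
}

func main() {
	fmt.Println(looksLikeGrammarOutput(`{"answer": 42}`)) // true
	fmt.Println(looksLikeGrammarOutput(`"just a string"`)) // false: root must be an object
}
```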
@@ -667,7 +681,7 @@ type completion struct {
 type CompletionRequest struct {
 	Prompt  string
-	Format  string
+	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
 }
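With `Format` widened from `string` to `json.RawMessage`, one field now carries three shapes of value. Illustrative literals, assuming the usual `encoding/json` import:

```go
var (
	noFormat   json.RawMessage             // nil: output unconstrained
	jsonMode   = json.RawMessage(`"json"`) // legacy JSON mode, mapped to jsonGrammar
	withSchema = json.RawMessage(`{
		"type": "object",
		"properties": {"age": {"type": "integer"}},
		"required": ["age"]
	}`) // a JSON schema, converted to a grammar downstream
)
```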
@@ -732,10 +746,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

-	if req.Format == "json" {
+	// TODO (parthsareen): Move conversion to grammar with sampling logic
+	// API should do error handling for invalid formats
+	if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
+		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
 			request["grammar"] = jsonGrammar
 			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-				slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+				slog.Warn("prompt does not specify that the LLM should respond in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+			}
+		} else if schema, err := func() (llama.JsonSchema, error) {
+			var schema llama.JsonSchema
+			err := json.Unmarshal(req.Format, &schema)
+			return schema, err
+		}(); err == nil {
+			request["grammar"] = schema.AsGrammar()
+		} else {
+			slog.Warn(`format is neither a schema nor "json"`, "format", req.Format)
 		}
 	}
...
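End to end, a Go client can pass a schema straight through to this code path. A hedged sketch using the ollama API client; the model name and schema are made up, and it assumes `api.GenerateRequest.Format` was widened to `json.RawMessage` in the same change:

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Constrain the reply to {"capital": "..."} via the new schema path.
	format := json.RawMessage(`{
		"type": "object",
		"properties": {"capital": {"type": "string"}},
		"required": ["capital"]
	}`)

	req := &api.GenerateRequest{
		Model:  "llama3.2", // hypothetical model name
		Prompt: "What is the capital of France? Respond in JSON.",
		Format: format,
	}
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```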
@@ -63,6 +63,11 @@ type Usage struct {
 type ResponseFormat struct {
 	Type       string      `json:"type"`
+	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
+}
+
+type JsonSchema struct {
+	Schema map[string]any `json:"schema"`
 }

 type EmbedRequest struct {
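On the OpenAI-compatible endpoint, the new `JsonSchema` field is filled from a `response_format` of type `json_schema`. An illustrative `/v1/chat/completions` body that the struct above would decode (model name made up):

```go
// Illustrative request body exercising the new field.
const body = `{
  "model": "llama3.2",
  "messages": [{"role": "user", "content": "List three colors as JSON."}],
  "response_format": {
    "type": "json_schema",
    "json_schema": {
      "schema": {
        "type": "object",
        "properties": {"colors": {"type": "array", "items": {"type": "string"}}},
        "required": ["colors"]
      }
    }
  }
}`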
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["top_p"] = 1.0
 	}

-	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
-		format = "json"
+	var format json.RawMessage
+	if r.ResponseFormat != nil {
+		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
+		// Support the old "json_object" type for OpenAI compatibility
+		case "json_object":
+			format = json.RawMessage(`"json"`)
+		case "json_schema":
+			if r.ResponseFormat.JsonSchema != nil {
+				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+				if err != nil {
+					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
+				}
+				format = schema
+			}
+		}
 	}

 	return &api.ChatRequest{
...
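Distilled, the conversion maps `json_object` to the legacy `"json"` marker and `json_schema` to the marshalled schema bytes, leaving anything else unconstrained. A hypothetical standalone helper capturing that mapping (not in the commit):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// formatFromResponseType mirrors the switch in fromChatRequest.
func formatFromResponseType(typ string, schema map[string]any) (json.RawMessage, error) {
	switch strings.ToLower(strings.TrimSpace(typ)) {
	case "json_object": // legacy OpenAI JSON mode
		return json.RawMessage(`"json"`), nil
	case "json_schema":
		if schema != nil {
			b, err := json.Marshal(schema)
			return json.RawMessage(b), err
		}
	}
	return nil, nil // unknown or empty type: leave output unconstrained
}

func main() {
	f, _ := formatFromResponseType("JSON_Object", nil) // matching is case-insensitive
	fmt.Println(string(f))                             // "json"
}
```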
@@ -13,6 +13,7 @@ import (
 	"time"

 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"

 	"github.com/ollama/ollama/api"
 )
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
 				"presence_penalty": 5.0,
 				"top_p":            6.0,
 			},
-			Format: "json",
+			Format: json.RawMessage(`"json"`),
 			Stream: &True,
 		},
 	},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
 			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 				t.Fatal(err)
 			}
+			return
 		}

-		if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-			t.Fatal("requests did not match")
+		if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+			t.Fatalf("requests did not match: %+v", diff)
 		}

-		if !reflect.DeepEqual(tc.err, errResp) {
-			t.Fatal("errors did not match")
+		if diff := cmp.Diff(tc.err, errResp); diff != "" {
+			t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 		}
 	})
 }
...
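Swapping `reflect.DeepEqual` for `cmp.Diff` makes failures actionable: instead of a bare "requests did not match", the test reports exactly which fields differ. A minimal illustration:

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

type request struct{ Format string }

func main() {
	want := request{Format: `"json"`}
	got := request{Format: ""}
	// cmp.Diff returns "" on equality, otherwise a -want/+got report.
	if diff := cmp.Diff(want, got); diff != "" {
		fmt.Printf("requests did not match (-want +got):\n%s", diff)
	}
}
```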
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}

-	if req.Format != "" && req.Format != "json" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
-		return
-	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
 		return
 	}
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	var b bytes.Buffer
 	if req.Context != nil {
+		slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
 		s, err := r.Detokenize(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
...