Commit 7004fab5 authored by xuxzh1

update

parent 42dd5af5
@@ -15,7 +15,7 @@ import (
func TestEstimateGPULayers(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16
modelName := "dummy"
f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err)
......
@@ -214,14 +214,36 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
params = append(params, "--threads", strconv.Itoa(defaultThreads))
}
flashAttnEnabled := envconfig.FlashAttention()
fa := envconfig.FlashAttention()
if fa && !gpus.FlashAttentionSupported() {
slog.Warn("flash attention enabled but not supported by gpu")
fa = false
}
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
if fa && !ggml.SupportsFlashAttention() {
slog.Warn("flash attention enabled but not supported by model")
fa = false
}
kvct := strings.ToLower(envconfig.KvCacheType())
if fa {
slog.Info("enabling flash attention")
params = append(params, "--flash-attn")
// Flash Attention also supports kv cache quantization
// Enable if requested and the kv cache type is supported by the model
if kvct != "" && ggml.SupportsKVCacheType(kvct) {
params = append(params, "--kv-cache-type", kvct)
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
// mmap has issues with partial offloading on metal
for _, g := range gpus {
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
@@ -231,9 +253,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
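For illustration only (not part of this commit): a minimal, self-contained sketch of the gating order the hunk above introduces, with plain booleans standing in for gpus.FlashAttentionSupported(), ggml.SupportsFlashAttention() and ggml.SupportsKVCacheType(); q8_0 is used here only as an example cache type.

```go
package main

import "fmt"

// runnerFlags mirrors the new precedence: flash attention is dropped if either
// the GPU or the model cannot support it, and a quantized kv cache type is only
// passed through when flash attention survives both checks and the model
// supports that cache type; otherwise the runner stays on the f16 default.
func runnerFlags(fa, gpuOK, modelOK bool, kvct string, cacheOK bool) []string {
	var params []string
	if fa && !gpuOK {
		fa = false // requested, but the GPU can't do it
	}
	if fa && !modelOK {
		fa = false // requested, but the model architecture can't
	}
	if fa {
		params = append(params, "--flash-attn")
		if kvct != "" && cacheOK {
			params = append(params, "--kv-cache-type", kvct)
		}
	}
	return params
}

func main() {
	fmt.Println(runnerFlags(true, true, true, "q8_0", true))  // [--flash-attn --kv-cache-type q8_0]
	fmt.Println(runnerFlags(true, false, true, "q8_0", true)) // [] (falls back to the f16 cache)
}
```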
@@ -617,27 +636,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
const jsonGrammar = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
@@ -667,7 +681,7 @@ type completion struct {
type CompletionRequest struct {
Prompt string
Format string
Format json.RawMessage
Images []ImageData
Options *api.Options
}
@@ -732,10 +746,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
if req.Format == "json" {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
// TODO (parthsareen): Move conversion to grammar with sampling logic
// API should do error handling for invalid formats
if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
request["grammar"] = jsonGrammar
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
}
} else if schema, err := func() (llama.JsonSchema, error) {
var schema llama.JsonSchema
err := json.Unmarshal(req.Format, &schema)
return schema, err
}(); err == nil {
request["grammar"] = schema.AsGrammar()
} else {
slog.Warn(`format is neither a schema nor "json"`, "format", req.Format)
}
}
......
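For context, a sketch (not part of this commit) of the two Format shapes the Completion path now accepts, built directly on the CompletionRequest struct shown above; it assumes the llm package is importable by a caller of the repository.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/llm"
)

func main() {
	// The quoted string "json" keeps the old behaviour: the built-in
	// jsonGrammar is attached to the request.
	legacy := llm.CompletionRequest{
		Prompt: "Respond in JSON.",
		Format: json.RawMessage(`"json"`),
	}

	// Any other non-null value is unmarshalled as a JSON schema and converted
	// to a grammar with llama.JsonSchema.AsGrammar.
	structured := llm.CompletionRequest{
		Prompt: "Name three colors.",
		Format: json.RawMessage(`{"type":"object","properties":{"colors":{"type":"array","items":{"type":"string"}}}}`),
	}

	fmt.Println(string(legacy.Format))
	fmt.Println(string(structured.Format))
}
```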
@@ -62,7 +62,12 @@ type Usage struct {
}
type ResponseFormat struct {
Type string `json:"type"`
Type string `json:"type"`
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}
type JsonSchema struct {
Schema map[string]any `json:"schema"`
}
type EmbedRequest struct {
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
options["top_p"] = 1.0
}
var format string
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
format = "json"
var format json.RawMessage
if r.ResponseFormat != nil {
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
// Support the old "json_object" type for OpenAI compatibility
case "json_object":
format = json.RawMessage(`"json"`)
case "json_schema":
if r.ResponseFormat.JsonSchema != nil {
schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
if err != nil {
return nil, fmt.Errorf("failed to marshal json schema: %w", err)
}
format = schema
}
}
}
return &api.ChatRequest{
......
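For reference, a self-contained sketch (not part of this commit) of the OpenAI-compatible response_format payload these structs describe; the types are copied locally from the hunk above rather than imported.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the request types shown above; the real definitions live in
// ollama's openai package.
type JsonSchema struct {
	Schema map[string]any `json:"schema"`
}

type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

func main() {
	// "json_schema" carries a schema that fromChatRequest marshals into the
	// downstream Format field; the older "json_object" type still maps to the
	// plain "json" grammar for OpenAI compatibility.
	rf := ResponseFormat{
		Type: "json_schema",
		JsonSchema: &JsonSchema{
			Schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"answer": map[string]any{"type": "string"},
				},
			},
		},
	}

	body, err := json.MarshalIndent(rf, "", "  ")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}
```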
@@ -13,6 +13,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: "json",
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
return
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
}
})
}
@@ -637,4 +638,4 @@ func TestRetrieveMiddleware(t *testing.T) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
}
\ No newline at end of file
@@ -148,10 +148,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
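For illustration, an end-to-end sketch (not part of this commit) of what the relaxed validation enables from the Go client, assuming api.GenerateRequest.Format now carries a raw JSON value in the same way as llm.CompletionRequest above; the model name is a placeholder.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Assumption: Format accepts either the quoted string "json" or a JSON
	// schema; the handler above no longer rejects non-"json" values.
	req := &api.GenerateRequest{
		Model:  "llama3.2", // placeholder model name
		Prompt: "Name three colors.",
		Format: json.RawMessage(`{"type":"object","properties":{"colors":{"type":"array","items":{"type":"string"}}}}`),
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```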
@@ -251,6 +248,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer
if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1582,4 +1580,4 @@ func handleScheduleError(c *gin.Context, name string, err error) {
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
}
\ No newline at end of file