Unverified Commit 85ccf735 authored by Michael Yang, committed by GitHub

gptoss: enable flash attention by default (#11996)

parent 30fb7e19
@@ -10,6 +10,7 @@ import (
 	"slices"
 	"strings"
 
+	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/util/bufioutil"
 )
@@ -479,7 +480,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	}, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
 	context *= uint64(numParallel)
 
 	embedding := f.KV().EmbeddingLength()
@@ -677,7 +678,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				kv[i] *= context
 			}
 		}
 		partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
+		if useFlashAttention {
+			// rough estimate of graph size with flash attention on
+			partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
+		}
 	}
 
 	return
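The new estimate replaces the headcount-based heuristic whenever flash attention is on: roughly 4 MiB per parallel sequence, plus 1 MiB per 1024 tokens of parallel-scaled context, plus a 110 MiB constant. A minimal standalone sketch of the arithmetic (the function name and the inlined mebiByte constant are illustrative, not part of the commit):

package main

import "fmt"

// mebiByte mirrors format.MebiByte from the ollama codebase (1024 * 1024);
// it is inlined here so the sketch is self-contained.
const mebiByte = 1024 * 1024

// flashAttentionGraphEstimate reproduces the commit's expression
// (4*numParallel + context>>10 + 110) * format.MebiByte. Note that
// GraphSize has already multiplied context by numParallel by this point.
func flashAttentionGraphEstimate(numCtx, numParallel uint64) uint64 {
	context := numCtx * numParallel
	return (4*numParallel + context>>10 + 110) * mebiByte
}

func main() {
	// An 8192-token context with 4 parallel sequences:
	// 4*4 + (32768 >> 10) + 110 = 16 + 32 + 110 = 158 MiB.
	fmt.Println(flashAttentionGraphEstimate(8192, 4)/mebiByte, "MiB")
}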
@@ -773,6 +779,13 @@ func (f GGML) SupportsFlashAttention() bool {
 	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
 }
 
+// FlashAttention checks if the model should enable flash attention
+func (f GGML) FlashAttention() bool {
+	return slices.Contains([]string{
+		"gptoss", "gpt-oss",
+	}, f.KV().String("general.architecture"))
+}
+
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
 func kvCacheBytesPerElement(cacheType string) float64 {
 	switch cacheType {
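The opt-in is keyed on the general.architecture field in the model's GGUF metadata, and both spellings seen for gpt-oss are matched. A hypothetical standalone version of the check (wantsFlashAttention is an illustrative name, not the commit's method):

package main

import (
	"fmt"
	"slices"
)

// wantsFlashAttention mirrors the new GGML.FlashAttention method as a free
// function: it reports whether an architecture string should enable flash
// attention by default.
func wantsFlashAttention(architecture string) bool {
	return slices.Contains([]string{"gptoss", "gpt-oss"}, architecture)
}

func main() {
	for _, arch := range []string{"gptoss", "gpt-oss", "llama"} {
		fmt.Printf("%s: %v\n", arch, wantsFlashAttention(arch)) // true, true, false
	}
}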
@@ -195,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		slog.Warn("model missing blk.0 layer size")
 	}
 
-	var kvct string
-	if envconfig.FlashAttention() &&
-		discover.GetGPUInfo().FlashAttentionSupported() &&
-		f.SupportsFlashAttention() {
+	useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
+		discover.GetGPUInfo().FlashAttentionSupported() &&
+		f.SupportsFlashAttention()
+
+	var kvct string
+	if useFlashAttention {
 		requested := strings.ToLower(envconfig.KvCacheType())
 		if requested != "" && f.SupportsKVCacheType(requested) {
 			kvct = requested
 		}
 	}
 
-	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
+	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
 
 	if len(kv) > 0 {
 		layerSize += kv[0]
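The effect of this hunk is that the model's default and the OLLAMA_FLASH_ATTENTION environment variable are now OR-ed together, while GPU and model support remain hard requirements. A reduction of that gating to a pure function (the name and signature are illustrative, not the commit's):

package main

import "fmt"

// shouldUseFlashAttention condenses the gating above: either an explicit
// user request or the model's default can turn flash attention on, but
// missing GPU or model support always vetoes it.
func shouldUseFlashAttention(envRequested, modelDefault, gpuSupported, modelSupported bool) bool {
	return (envRequested || modelDefault) && gpuSupported && modelSupported
}

func main() {
	// gpt-oss on capable hardware: enabled with no env var set.
	fmt.Println(shouldUseFlashAttention(false, true, true, true)) // true
	// An unsupported GPU vetoes even an explicit request.
	fmt.Println(shouldUseFlashAttention(true, false, false, true)) // false
}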
@@ -195,6 +195,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
 	// that can handle it.
 	fa := envconfig.FlashAttention()
+	if f.FlashAttention() {
+		slog.Info("model wants flash attention")
+		fa = true
+	}
+
 	if fa && !gpus.FlashAttentionSupported() {
 		slog.Warn("flash attention enabled but not supported by gpu")
 		fa = false