Commit 2d6eac90 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

kvcache: Optimize sliding window attention

Currently sliding window attention allocates and uses the full
context size and just masks out any tokens that are outside of the
window. However, we really only need (roughly) the sliding window
size.

At large context sizes this improves two things:
 - Memory allocated - since the fully context size is allocated up front,
   memory requirements drop substantially. On Gemma3:4b with a 32k
   context window, total memory usage (including weights and non-sliding
   layers) drops from ~20GB to ~8GB.
 - Computation - ranges that are completely outside of the sliding
   window are now removed from the tensors that are returned from the
   cache rather than simply being masked out. This results in more
   efficient processing, scaling with the size of the context that
   has actually been used.

Notable, this does not update the scheduler for any model to be aware of
the smaller memory requirements. This is difficult for Gemma3 because
the layers are heterogeneous between sliding and non-sliding attention.
As a result, while actual memory consumption will be reduced, the
scheduler will over-estimate the requirements of the model. This means
that splitting between GPUs or GPUs and CPUs will still be suboptimal.

Bug #9730
parent 3ed7ad3a
...@@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity ...@@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.MaskDType = ml.DTypeF32 c.config.MaskDType = ml.DTypeF32
} }
cacheSize := maxSequences * capacity var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding) cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize) c.cells = make([]cacheCell, cacheSize)
...@@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error { ...@@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curPositions = batch.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
c.updateSlidingWindow()
var err error var err error
c.curLoc, err = c.findStartLoc() c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) { if errors.Is(err, ErrKvCacheFull) {
...@@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) { ...@@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) {
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells)) return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
} }
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
}
func roundDown(length, pad int) int { func roundDown(length, pad int) int {
return (length / pad) * pad return (length / pad) * pad
} }
......
...@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) { ...@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil) cache := NewSWACache(1, nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF32, 1, 16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
name: "SlidingWindow", name: "FirstBatch",
in: []float32{1, 2, 3, 4}, in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4}, inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0}, seqs: []int{0, 0, 0, 0},
...@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) { ...@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4}, expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
}, },
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
} }
testCache(t, backend, cache, tests) testCache(t, backend, cache, tests)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment