Commit c116a752, authored by Jesse Gross and committed by Jesse Gross
Browse files

kvcache: Don't shift empty batches

When we context shift, we delete half the context and apply RoPE
with an offset to the other half. We used to RoPE across the entire
context in a single pass with a zero offset for the deleted
section. With the change to shifting in batches, we can skip any
batches where all of the offsets would be zero. This typically
reduces the number of operations by half.
parent 3515cc37
...@@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { ...@@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
seqRange := c.cellRanges[seq] seqRange := c.cellRanges[seq]
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
ctx := c.backend.NewContext()
size := min(seqRange.max-start+1, c.maxBatch) size := min(seqRange.max-start+1, c.maxBatch)
offsets := make([]int32, size) offsets := make([]int32, size)
var batchFirst, batchLast int
batchFirst = -1
for i := range offsets { for i := range offsets {
cell := c.cells[start+i] cell := c.cells[start+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset offsets[i] = offset
if batchFirst < 0 {
batchFirst = i
}
batchLast = i
} }
} }
if batchFirst < 0 {
continue
}
offsets = offsets[batchFirst : batchLast+1]
ctx := c.backend.NewContext()
kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
for i, key := range c.keys { for i, key := range c.keys {
...@@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { ...@@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
numKVHeads := key.Dim(1) numKVHeads := key.Dim(1)
rowSize := key.Stride(2) rowSize := key.Stride(2)
key = key.View(ctx, rowSize*start, key = key.View(ctx, rowSize*(start+batchFirst),
kHeadDim, key.Stride(1), kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2), numKVHeads, key.Stride(2),
size, len(offsets),
) )
roped, err := c.shiftFn(ctx, i, key, kShift) roped, err := c.shiftFn(ctx, i, key, kShift)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment