Commit 3ed7ad3a authored and committed by Jesse Gross
Browse files

kvcache: Pass granular cache size into implementations

Currently the runner computes the kv size needed and creates a
cache of that size. This is the context size times number of
parallel sequences.

Cache implementations can make better decisions about their memory
usage, so instead pass in the required capacity, number of sequences
and maximum batch size. For now, the causal cache just uses this to
compute the size in the same way as before.
parent 6d110304
...@@ -43,8 +43,13 @@ type Cache interface { ...@@ -43,8 +43,13 @@ type Cache interface {
// ** cache management ** // ** cache management **
// Init sets up runtime parameters // Init sets up runtime parameters.
Init(backend ml.Backend, dtype ml.DType, capacity int32) // backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it // Close closes the cache and frees resources associated with it
Close() Close()
......
...@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e ...@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size // The mask is of shape history size, batch size
type Causal struct { type Causal struct {
DType ml.DType DType ml.DType
Capacity int32
windowSize int32 windowSize int32
opts CausalOptions opts CausalOptions
...@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal { ...@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
} }
} }
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok { if cc, ok := backend.(ml.BackendCacheConfig); ok {
...@@ -119,9 +118,11 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { ...@@ -119,9 +118,11 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.config.MaskDType = ml.DTypeF32 c.config.MaskDType = ml.DTypeF32
} }
cacheSize := maxSequences * capacity
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange) c.cellRanges = make(map[int]cellRange)
c.backend = backend c.backend = backend
} }
...@@ -210,7 +211,7 @@ func (c *Causal) findStartLoc() (int, error) { ...@@ -210,7 +211,7 @@ func (c *Causal) findStartLoc() (int, error) {
} }
} }
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
} }
func roundDown(length, pad int) int { func roundDown(length, pad int) int {
...@@ -265,7 +266,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { ...@@ -265,7 +266,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil return maskTensor, nil
} }
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys { for i, key := range c.keys {
if key == nil { if key == nil {
continue continue
...@@ -275,8 +276,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { ...@@ -275,8 +276,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
numKVHeads := key.Dim(1) numKVHeads := key.Dim(1)
rowSize := key.Stride(2) rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i] value := c.values[i]
var vSrcView, vDstView ml.Tensor var vSrcView, vDstView ml.Tensor
...@@ -284,14 +285,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { ...@@ -284,14 +285,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
vHeadDim := value.Dim(1) vHeadDim := value.Dim(1)
elemSize := value.Stride(0) elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else { } else {
vHeadDim := value.Dim(0) vHeadDim := value.Dim(0)
rowSize := value.Stride(2) rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
} }
ctx.Forward( ctx.Forward(
...@@ -480,14 +481,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { ...@@ -480,14 +481,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
} }
if _, ok := c.keys[c.curLayer]; !ok { if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
} }
if _, ok := c.values[c.curLayer]; !ok { if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV { if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else { } else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
} }
} }
...@@ -498,7 +499,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { ...@@ -498,7 +499,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0) elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3) value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else { } else {
rowSize := c.values[c.curLayer].Stride(2) rowSize := c.values[c.curLayer].Stride(2)
......
...@@ -25,7 +25,7 @@ func TestStore(t *testing.T) { ...@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
...@@ -58,7 +58,7 @@ func TestSWA(t *testing.T) { ...@@ -58,7 +58,7 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil) cache := NewSWACache(1, nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF32, 16) cache.Init(backend, ml.DTypeF32, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
...@@ -81,7 +81,7 @@ func TestSequences(t *testing.T) { ...@@ -81,7 +81,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
...@@ -116,7 +116,7 @@ func TestRemove(t *testing.T) { ...@@ -116,7 +116,7 @@ func TestRemove(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
...@@ -181,7 +181,7 @@ func TestDefrag(t *testing.T) { ...@@ -181,7 +181,7 @@ func TestDefrag(t *testing.T) {
}) })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
...@@ -229,7 +229,7 @@ func TestCopy(t *testing.T) { ...@@ -229,7 +229,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close() defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16) cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{ tests := []testCase{
{ {
......
...@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache { ...@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
} }
} }
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok { if cc, ok := backend.(ml.BackendCacheConfig); ok {
...@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) ...@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
c.config = &config c.config = &config
} }
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 { if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
} }
......
...@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache { ...@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
} }
} }
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
for _, cache := range c.caches { for _, cache := range c.caches {
cache.Init(backend, dtype, capacity) cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
} }
} }
......
...@@ -31,8 +31,10 @@ type InputCache struct { ...@@ -31,8 +31,10 @@ type InputCache struct {
cache kvcache.Cache cache kvcache.Cache
} }
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) { func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 { numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
} }
...@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots ...@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache cache := model.Config().Cache
if cache != nil { if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize) cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
} }
return &InputCache{ return &InputCache{
numCtx: kvSize / int32(numSlots), numCtx: numCtx,
enabled: cache != nil, enabled: cache != nil,
slots: slots, slots: slots,
multiUserCache: multiUserCache, multiUserCache: multiUserCache,
......
...@@ -699,7 +699,7 @@ func (s *Server) loadModel( ...@@ -699,7 +699,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented") panic("loras are not yet implemented")
} }
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache) s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment