package kvcache

import (
	"errors"
	"fmt"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
21
22
23
24
25
26
27
28
29
30
	DType ml.DType

	// swaWindowSize is the number of tokens that will be included in the mask
	// during attention operations. swaMemorySize is the number of tokens that
	// will be retained in memory for partial prefix caching. Set to math.MaxInt32
	// for unlimited or if sliding window attention is not being used.
	swaWindowSize int32
	swaMemorySize int32

	chunkSize int32
Jesse Gross's avatar
Jesse Gross committed
31

32
33
	opts CausalOptions

34
35
36
	// maxBatch is the largest batch that we might receive
	maxBatch int

37
38
39
	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

Jesse Gross's avatar
Jesse Gross committed
40
41
42
43
44
	// ** current forward pass **

	// size of the current batch
	curBatchSize int

45
46
47
	// locations for data storage for this batch
	curLoc ml.Tensor

Jesse Gross's avatar
Jesse Gross committed
48
49
50
	// mask of the cache as used by this batch
	curMask ml.Tensor

51
52
53
	// the active layer for Get and Put
	curLayer int

Jesse Gross's avatar
Jesse Gross committed
54
55
56
	// locations in the cache that are needed for this batch
	curCellRange cellRange

57
58
59
60
61
62
	// curSequences is the sequences corresponding to this pass's entries in the cache
	curSequences []int

	// curPositions is the positions corresponding to this pass's entries in the cache
	curPositions []int32

Jesse Gross's avatar
Jesse Gross committed
63
64
65
66
67
68
69
70
71
72
73
74
75
	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	shiftFn      shiftFn
	backend      ml.Backend
76
77
	ctxs         map[int]ml.Context
	keys, values map[int]ml.Tensor
Jesse Gross's avatar
Jesse Gross committed
78
79
80
81
82
83
84
85
86
87
88
89
90
}

type cacheCell struct {
	pos       int32
	sequences []int
}

type cellRange struct {
	min int
	max int
}

func NewCausalCache(shift shiftFn) *Causal {
91
	return &Causal{
92
93
94
95
		shiftFn: shift,
		ctxs:    make(map[int]ml.Context),
		keys:    make(map[int]ml.Tensor),
		values:  make(map[int]ml.Tensor),
96
	}
Jesse Gross's avatar
Jesse Gross committed
97
98
99
}

func NewSWACache(windowSize int32, shift shiftFn) *Causal {
100
	return &Causal{
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
		swaWindowSize: windowSize,
		shiftFn:       shift,
		ctxs:          make(map[int]ml.Context),
		keys:          make(map[int]ml.Tensor),
		values:        make(map[int]ml.Tensor),
	}
}

func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
	return &Causal{
		swaWindowSize: windowSize,
		swaMemorySize: memorySize,
		shiftFn:       shift,
		ctxs:          make(map[int]ml.Context),
		keys:          make(map[int]ml.Tensor),
		values:        make(map[int]ml.Tensor),
117
	}
Jesse Gross's avatar
Jesse Gross committed
118
119
}

Michael Yang's avatar
Michael Yang committed
120
121
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
	return &Causal{
122
123
124
125
126
		chunkSize: chunkSize,
		shiftFn:   shift,
		ctxs:      make(map[int]ml.Context),
		keys:      make(map[int]ml.Tensor),
		values:    make(map[int]ml.Tensor),
Michael Yang's avatar
Michael Yang committed
127
128
129
	}
}

130
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
131
132
133
134
135
136
137
138
139
140
141
142
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding == 0 {
		c.config.CachePadding = 1
	}

143
144
145
146
	if c.config.MaskDType == ml.DTypeOther {
		c.config.MaskDType = ml.DTypeF32
	}

147
148
149
150
151
152
	if c.swaWindowSize == 0 {
		c.swaWindowSize = math.MaxInt32
	}
	if c.swaMemorySize == 0 {
		c.swaMemorySize = c.swaWindowSize
	}
153
154
155
156
157
158
159
160
161
	// We will allocate space in the cache for the stop token, which won't be part of a follow on
	// sequence, so allocate an extra token of storage to ensure that we can jump back without
	// causing a cache break. As an optimization, only do this when we have parallel sequences
	// because the extra token will live in the batch buffer and won't get overwritten if we
	// only have a single sequence.
	if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
		c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
	}
	if int(c.swaMemorySize) >= capacity {
162
163
164
165
166
167
168
		c.swaMemorySize = math.MaxInt32
	}

	if c.swaMemorySize < c.swaWindowSize {
		panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
	}

169
	var cacheSize int
170
	if c.swaMemorySize == math.MaxInt32 {
171
172
		cacheSize = maxSequences * capacity
	} else {
173
		cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
174
	}
175
176
177
	cacheSize = roundUp(cacheSize, c.config.CachePadding)
	c.cells = make([]cacheCell, cacheSize)

Jesse Gross's avatar
Jesse Gross committed
178
179
180
	c.DType = dtype
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
181
	c.maxBatch = maxBatch
Jesse Gross's avatar
Jesse Gross committed
182
183
}

184
185
186
187
188
189
190
191
func (c *Causal) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	c.config = &config
}

Jesse Gross's avatar
Jesse Gross committed
192
func (c *Causal) Close() {
193
194
195
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
Jesse Gross's avatar
Jesse Gross committed
196
197
}

198
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
Jesse Gross's avatar
Jesse Gross committed
199
200
201
	c.curBatchSize = len(batch.Positions)
	c.curSequences = batch.Sequences
	c.curPositions = batch.Positions
202
	c.opts.Except = nil
Jesse Gross's avatar
Jesse Gross committed
203

204
	var locs []int32
205
	if !reserve {
206
		c.updateSlidingWindow()
207

208
		var err error
209
		locs, err = c.findLocs()
210
211
		if err != nil {
			return err
Jesse Gross's avatar
Jesse Gross committed
212
213
		}

214
215
		for i, pos := range batch.Positions {
			seq := batch.Sequences[i]
216
			loc := int(locs[i])
217

218
			c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
219
220
221
222
223
224

			seqRange, ok := c.cellRanges[seq]
			if !ok {
				seqRange = newRange()
			}

225
226
			seqRange.min = min(seqRange.min, loc)
			c.curCellRange.min = min(c.curCellRange.min, loc)
227

228
229
			seqRange.max = max(seqRange.max, loc)
			c.curCellRange.max = max(c.curCellRange.max, loc)
230
231

			c.cellRanges[seq] = seqRange
Jesse Gross's avatar
Jesse Gross committed
232
		}
233
234
235
	} else {
		// If we are reserving memory, don't update any of the cache metadata but set the size
		// to the worst case.
236
237
238
239
		locs = make([]int32, c.curBatchSize)
		for i := range locs {
			locs[i] = int32(i)
		}
240
241
		c.curCellRange.min = 0
		c.curCellRange.max = len(c.cells) - 1
Jesse Gross's avatar
Jesse Gross committed
242
243
	}

244
	c.curLoc = ctx.Input().FromInts(locs, len(locs))
245
	c.curMask = c.buildMask(ctx)
Jesse Gross's avatar
Jesse Gross committed
246

247
	return nil
Jesse Gross's avatar
Jesse Gross committed
248
249
250
251
252
253
254
255
256
}

func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}

257
258
259
260
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
	loc := make([]int32, 0, c.curBatchSize)

Jesse Gross's avatar
Jesse Gross committed
261
262
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
263
264
265
			loc = append(loc, int32(i))
			if len(loc) >= c.curBatchSize {
				return loc, nil
Jesse Gross's avatar
Jesse Gross committed
266
267
268
269
			}
		}
	}

270
	return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
Jesse Gross's avatar
Jesse Gross committed
271
272
}

273
func (c *Causal) updateSlidingWindow() {
274
275
276
277
278
279
280
281
282
283
	c.curCellRange = newRange()

	if c.swaMemorySize == math.MaxInt32 {
		for _, seq := range c.curSequences {
			if seqRange, ok := c.cellRanges[seq]; ok {
				c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
				c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
			}
		}

284
285
286
		return
	}

287
288
289
290
291
	type lowestPosition struct {
		pos      int32
		curBatch bool
	}

292
	// create a map of unique sequences to the lowest position in that sequence
293
	lowestPos := make(map[int]lowestPosition)
294
295
296
	for i := range c.curPositions {
		seq := c.curSequences[i]

297
		lowest, ok := lowestPos[seq]
298
		if !ok {
299
300
301
			lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
		} else if c.curPositions[i] < lowest.pos {
			lowest.pos = c.curPositions[i]
302
303
		}

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
		lowestPos[seq] = lowest
	}

	// for any sequences are not part of this batch, clean up any tokens
	// that are no longer needed after the processing of the previous
	// batch
	for seq, seqRange := range c.cellRanges {
		if _, ok := lowestPos[seq]; !ok {
			var last int32
			for i := seqRange.min; i <= seqRange.max; i++ {
				if slices.Contains(c.cells[i].sequences, seq) {
					last = max(last, c.cells[i].pos)
				}
			}

			lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
		}
321
322
323
	}

	// delete any entries that are beyond the window of the oldest position in the sequence
324
	for seq, lowest := range lowestPos {
325
326
327
328
329
330
331
332
333
		oldRange, ok := c.cellRanges[seq]
		if !ok {
			continue
		}

		newRange := newRange()

		for i := oldRange.min; i <= oldRange.max; i++ {
			if slices.Contains(c.cells[i].sequences, seq) {
334
				if c.cells[i].pos < lowest.pos-c.swaMemorySize {
335
336
337
338
339
					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
				} else {
					newRange.min = min(newRange.min, i)
					newRange.max = max(newRange.max, i)
				}
340
				if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
341
342
343
					c.curCellRange.min = min(c.curCellRange.min, i)
					c.curCellRange.max = max(c.curCellRange.max, i)
				}
344
345
346
347
348
349
350
			}
		}

		c.cellRanges[seq] = newRange
	}
}

351
352
353
354
355
356
357
358
func roundDown(length, pad int) int {
	return (length / pad) * pad
}

func roundUp(length, pad int) int {
	return ((length + pad - 1) / pad) * pad
}

Jesse Gross's avatar
Jesse Gross committed
359
360
361
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
362
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
363
364
365
366
	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

	length := c.curCellRange.max - c.curCellRange.min + 1
367

368
	mask := make([]float32, c.curBatchSize*length)
Jesse Gross's avatar
Jesse Gross committed
369
370

	for i := range c.curBatchSize {
371
		enabled := !slices.Contains(c.opts.Except, i)
Jesse Gross's avatar
Jesse Gross committed
372
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
373
			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
374
				(enabled && c.cells[j].pos > c.curPositions[i]) ||
Michael Yang's avatar
Michael Yang committed
375
				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
376
				c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
377
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
Jesse Gross's avatar
Jesse Gross committed
378
379
380
381
			}
		}
	}

382
	maskTensor := ctx.Input().FromFloats(mask, length, c.curBatchSize)
383
384

	if c.config.MaskDType != ml.DTypeF32 {
385
		maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
386
387
	}

388
	return maskTensor
Jesse Gross's avatar
Jesse Gross committed
389
390
391
392
393
394
}

func (c *Causal) SetLayer(layer int) {
	c.curLayer = layer
}

395
type CausalOptions struct {
396
397
	// Enabled controls whether the causal mask is generated for a particular index in a batch
	Except []int
398
399
}

400
401
// SetCausal disables causal mask generation for a particular range of indicies in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
402
403
404
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
	if !slices.Equal(c.opts.Except, opts.Except) {
		c.opts = opts
405
		if ctx != nil {
406
			c.curMask = c.buildMask(ctx)
407
408
409
410
		}
	}
}

Jesse Gross's avatar
Jesse Gross committed
411
412
413
414
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

415
416
417
418
	kHeadDim := key.Dim(0)
	numKVHeads := key.Dim(1)
	rowSize := key.Stride(2)
	cachedSize := c.curMask.Dim(0)
Jesse Gross's avatar
Jesse Gross committed
419

420
421
422
423
	key = key.View(ctx, rowSize*c.curCellRange.min,
		kHeadDim, key.Stride(1),
		numKVHeads, key.Stride(2),
		cachedSize,
Jesse Gross's avatar
Jesse Gross committed
424
425
	)

426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
	if c.config.PermutedV {
		vHeadDim := value.Dim(1)
		elemSize := value.Stride(0)

		value = value.View(ctx, elemSize*c.curCellRange.min,
			cachedSize, value.Stride(1),
			vHeadDim, value.Stride(2),
			numKVHeads,
		)
	} else {
		vHeadDim := value.Dim(0)
		rowSize := value.Stride(2)

		value = value.View(ctx, rowSize*c.curCellRange.min,
			vHeadDim, value.Stride(1),
			numKVHeads, value.Stride(2),
			cachedSize,
		)
	}

Jesse Gross's avatar
Jesse Gross committed
446
447
448
449
	return key, value, c.curMask
}

func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
450
451
452
453
454
455
456
	kHeadDim := key.Dim(0)
	vHeadDim := value.Dim(0)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)

	if c.curBatchSize != batchSize {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
Jesse Gross's avatar
Jesse Gross committed
457
458
	}

459
	if _, ok := c.ctxs[c.curLayer]; !ok {
460
		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
461
462
463
	}

	if _, ok := c.keys[c.curLayer]; !ok {
464
		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
465
	}
466

467
	if _, ok := c.values[c.curLayer]; !ok {
468
		if c.config.PermutedV {
469
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
470
		} else {
471
			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
472
		}
Jesse Gross's avatar
Jesse Gross committed
473
474
	}

475
476
477
478
	key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
	keyCache := c.keys[c.curLayer]
	keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
	ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
479
480

	if c.config.PermutedV {
481
482
483
484
485
		value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
		value = value.Permute(ctx, 2, 0, 1, 3)

		valueCache := c.values[c.curLayer]
		valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
486

487
		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
488
	} else {
489
490
491
		value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
		valueCache := c.values[c.curLayer]
		valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
492

493
		ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
494
	}
Jesse Gross's avatar
Jesse Gross committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
}

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		// Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
		if slices.Contains(c.cells[i].sequences, dstSeq) {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
		}

		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}

520
func (c *Causal) CanResume(seq int, pos int32) bool {
521
	if c.swaMemorySize == math.MaxInt32 {
522
523
524
525
526
527
528
529
530
531
		return true
	}

	seqRange, ok := c.cellRanges[seq]
	if !ok {
		return false
	}

	// for sliding window, check that the window of the new sequence is contained in
	// the window of what we are storing
532
	var first int32 = math.MaxInt32
533
534
535
	var last int32 = -1
	for i := seqRange.min; i <= seqRange.max; i++ {
		if slices.Contains(c.cells[i].sequences, seq) {
536
			first = min(first, c.cells[i].pos)
537
538
539
540
541
542
543
544
			last = max(last, c.cells[i].pos)
		}
	}

	if last == -1 {
		return false
	}

545
	posWindowStart := max(0, pos-c.swaWindowSize)
546
	return posWindowStart >= first && pos <= last+1
547
548
}

Jesse Gross's avatar
Jesse Gross committed
549
550
551
552
553
554
555
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	seqRange := c.cellRanges[seq]

556
557
558
	for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
		size := min(seqRange.max-start+1, c.maxBatch)
		offsets := make([]int32, size)
559
560
561
562

		var batchFirst, batchLast int

		batchFirst = -1
563
564
565
566
567
		for i := range offsets {
			cell := c.cells[start+i]

			if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
				offsets[i] = offset
568
569
570
571
				if batchFirst < 0 {
					batchFirst = i
				}
				batchLast = i
572
			}
Jesse Gross's avatar
Jesse Gross committed
573
574
		}

575
576
577
578
579
580
581
		if batchFirst < 0 {
			continue
		}

		offsets = offsets[batchFirst : batchLast+1]

		ctx := c.backend.NewContext()
Michael Yang's avatar
Michael Yang committed
582
		kShift := ctx.Input().FromInts(offsets, len(offsets))
Jesse Gross's avatar
Jesse Gross committed
583

584
585
586
587
		for i, key := range c.keys {
			if key == nil {
				continue
			}
Jesse Gross's avatar
Jesse Gross committed
588

589
590
591
			kHeadDim := key.Dim(0)
			numKVHeads := key.Dim(1)
			rowSize := key.Stride(2)
592

593
			key = key.View(ctx, rowSize*(start+batchFirst),
594
595
				kHeadDim, key.Stride(1),
				numKVHeads, key.Stride(2),
596
				len(offsets),
597
			)
Jesse Gross's avatar
Jesse Gross committed
598

599
600
601
602
603
604
605
			roped, err := c.shiftFn(ctx, i, key, kShift)
			if err != nil {
				ctx.Close()
				return err
			}

			ctx.Forward(roped.Copy(ctx, key))
Jesse Gross's avatar
Jesse Gross committed
606
607
		}

608
609
		ctx.Compute()
		ctx.Close()
Jesse Gross's avatar
Jesse Gross committed
610
611
612
613
614
615
	}

	return nil
}

func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
616
617
618
619
620
621
	// TODO(jessegross): We should check to see if removing the middle of the sequence will
	// cause the sliding window to encompass tokens that we no longer have. If so, then we
	// should return an error, which will trigger the runner to evaluate the full history and
	// rebuild the window. However, if we have multimodal inputs in our history, this reuse
	// results in use after free, so we don't do it for now.

Jesse Gross's avatar
Jesse Gross committed
622
623
624
625
626
627
628
629
630
631
632
633
634
635
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
636
						return errors.New("shifting cells shared by multiple sequences not supported")
Jesse Gross's avatar
Jesse Gross committed
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
					}

					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}