"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "97ec8cfd4ef13190f3939fbb24b6f146d570ed12"
Commit 1c093e97 authored by Jesse Gross, committed by Jesse Gross
Browse files

kvcache: Remove special case for reservation mask

We currently short circuit generation of the cache mask and just
generate an empty tensor of the correct size. However, in some
cases, this can also skip a cast operation. As a result, the
worst-case graph may not actually represent the worst case.

We don't actually need the fast path for mask generation, so it's
better to just use the normal code path.
parent a8d9c264
...@@ -40,11 +40,6 @@ type Causal struct { ...@@ -40,11 +40,6 @@ type Causal struct {
// ** current forward pass ** // ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put // the active layer for Get and Put
curLayer int curLayer int
...@@ -206,13 +201,12 @@ func (c *Causal) Close() { ...@@ -206,13 +201,12 @@ func (c *Causal) Close() {
} }
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions) c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences c.curSequences = batch.Sequences
c.curPositions = batch.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
if !c.curReserve { if !reserve {
c.updateSlidingWindow() c.updateSlidingWindow()
var err error var err error
...@@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { ...@@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
length := c.curCellRange.max - c.curCellRange.min + 1 length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length) mask := make([]float32, batchSize*length)
for i := range c.curBatchSize { for i := range c.curBatchSize {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment