Commit cb485b20 authored by Jesse Gross, committed by Jesse Gross

kvcache: Run tests both with and without PermutedV

The causal cache can store data differently depending on what is
best for the backend. We should run tests both ways.
parent b2af5096
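The helper added in this commit registers each test body twice, as subtests named PermutedV=false and PermutedV=true, so a single layout can also be exercised on its own with go test's -run filter. For example (the ./kvcache package path is an assumption about where these tests live):

go test ./kvcache -run 'TestStore/PermutedV=true' -v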

package kvcache

import (
	"fmt"
	"math"
	"slices"
	"testing"
@@ -20,217 +21,184 @@ type testCase struct {
	expectedMask []float32
}

func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
	t.Helper()

	for _, permuted := range []bool{false, true} {
		t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
			fn(t, &testBackend{permutedV: permuted})
		})
	}
}

func TestStore(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				inShape:       []int{2, 3, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
				expectedShape: []int{2, 3, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
			},
			{
				name:          "SecondBatch",
				in:            []float32{115, 215, 125, 225, 135, 235},
				inShape:       []int{2, 3, 1},
				seqs:          []int{0},
				pos:           []int32{4},
				expected:      []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
				expectedShape: []int{2, 3, 5},
				expectedMask:  []float32{0, 0, 0, 0, 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestSWA(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 6, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestSWASeparateBatches(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWACache(1, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 2, 16, 2)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "First seq 0",
				in:            []float32{1, 2},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{0, 1},
				expected:      []float32{1, 2},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			{
				name:          "Second seq 0",
				in:            []float32{3, 4},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{2, 3},
				expected:      []float32{2, 3, 4},
				expectedShape: []int{1, 1, 3},
				expectedMask: []float32{
					0, 0, x,
					x, 0, 0,
				},
			},
			{
				name:          "First seq 1",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{0, 1},
				expected:      []float32{5, 6},
				expectedShape: []int{1, 1, 2},
				expectedMask: []float32{
					0, x,
					0, 0,
				},
			},
			{
				name:          "Second seq 1",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{2, 3},
				expected:      []float32{6, 3, 4, 7, 8},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					x, x, x, 0, 0,
				},
			},
			{
				name:          "Third seq 0",
				in:            []float32{9, 10},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{9, 10, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, 0,
					0, 0, x, x,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewSWAMemCache(1, 3, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
@@ -242,190 +210,240 @@ func TestChunkedAttention(t *testing.T) {
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, 0, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{4, 5},
				expected:      []float32{5, 2, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, x, x, 0, x,
					0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestChunkedAttention(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewChunkedAttentionCache(2, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		testCache(
			t, backend, cache,
			[]testCase{
				{
					name:          "FirstBatch",
					in:            []float32{1, 2, 3, 4},
					inShape:       []int{1, 1, 4},
					seqs:          []int{0, 0, 0, 0},
					pos:           []int32{0, 1, 2, 3},
					expected:      []float32{1, 2, 3, 4},
					expectedShape: []int{1, 1, 4},
					expectedMask: []float32{
						0, x, x, x,
						0, 0, x, x,
						x, x, 0, x,
						x, x, 0, 0,
					},
				},
				{
					name:          "SecondBatch",
					in:            []float32{5, 6, 7},
					inShape:       []int{1, 1, 3},
					seqs:          []int{0, 0, 0},
					pos:           []int32{4, 5, 6},
					expected:      []float32{1, 2, 3, 4, 5, 6, 7},
					expectedShape: []int{1, 1, 7},
					expectedMask: []float32{
						x, x, x, x, 0, x, x,
						x, x, x, x, 0, 0, x,
						x, x, x, x, x, x, 0,
					},
				},
				{
					name:          "ThirdBatch",
					in:            []float32{8, 9},
					inShape:       []int{1, 1, 2},
					seqs:          []int{0, 0},
					pos:           []int32{7, 8},
					expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
					expectedShape: []int{1, 1, 9},
					expectedMask: []float32{
						x, x, x, x, x, x, 0, 0, x,
						x, x, x, x, x, x, x, x, 0,
					},
				},
			},
		)
	})
}

func TestSequences(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{2, 2},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestRemove(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
			return key.Add(ctx, shift), nil
		})
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		x := float32(math.Inf(-1))

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 1, 1},
				pos:           []int32{0, 1, 0, 1},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		err := cache.Remove(0, 1, math.MaxInt32)
		if err != nil {
			panic(err)
		}

		tests = []testCase{
			{
				name:          "RemoveEnd",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 1},
				pos:           []int32{1, 2},
				expected:      []float32{1, 5, 3, 4, 6},
				expectedShape: []int{1, 1, 5},
				expectedMask: []float32{
					0, 0, x, x, x,
					x, x, 0, 0, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)

		err = cache.Remove(0, 0, 1)
		if err != nil {
			panic(err)
		}

		tests = []testCase{
			{
				name:          "RemoveMiddle",
				in:            []float32{7, 8},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{1, 2},
				expected:      []float32{7, 4, 3, 4, 6, 8},
				expectedShape: []int{1, 1, 6},
				expectedMask: []float32{
					0, 0, x, x, x, x,
					0, 0, x, x, x, 0,
				},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func TestCopy(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		tests := []testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
			},
		}

		testCache(t, backend, cache, tests)

		cache.CopyPrefix(0, 1, 2)

		tests = []testCase{
			{
				name:          "Copy",
				in:            []float32{5, 6},
				inShape:       []int{1, 1, 2},
				seqs:          []int{1, 1},
				pos:           []int32{3, 4},
				expected:      []float32{1, 2, 3, 4, 5, 6},
				expectedShape: []int{1, 1, 6},
				expectedMask:  []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
			},
		}

		testCache(t, backend, cache, tests)
	})
}

func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
@@ -463,145 +481,148 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}

func TestCanResume(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		cache := NewSWACache(windowSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		context := backend.NewContext()
		defer context.Close()

		err := cache.StartForward(context, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4},
			Sequences: []int{0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
		cache.Put(context, tensor, tensor)

		// with window size 4, nothing has slid out of the window yet
		if !cache.CanResume(0, 0) {
			t.Errorf("CanResume(0, 0) = false, want true (within window)")
		}
		if !cache.CanResume(0, 1) {
			t.Errorf("CanResume(0, 1) = false, want true (within window)")
		}
		if !cache.CanResume(0, 2) {
			t.Errorf("CanResume(0, 2) = false, want true (within window)")
		}
		if !cache.CanResume(0, 3) {
			t.Errorf("CanResume(0, 3) = false, want true (latest position)")
		}
		if !cache.CanResume(0, 4) {
			t.Errorf("CanResume(0, 4) = false, want true (latest position)")
		}

		// shift window by adding position 5
		err = cache.StartForward(context, input.Batch{
			Positions: []int32{5},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = context.FromFloats([]float32{6}, 1, 1, 1)
		cache.Put(context, tensor, tensor)

		// only the latest position has overlapping windows
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}
		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}
		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}
		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}
		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}
		if !cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
		}
	})
}

func TestCanResumeSWAMem(t *testing.T) {
	runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
		windowSize := int32(4)
		memSize := int32(5)
		cache := NewSWAMemCache(windowSize, memSize, nil)
		defer cache.Close()

		cache.Init(backend, ml.DTypeF16, 1, 16, 16)

		context := backend.NewContext()
		defer context.Close()

		err := cache.StartForward(context, input.Batch{
			Positions: []int32{0, 1, 2, 3, 4, 5, 6},
			Sequences: []int{0, 0, 0, 0, 0, 0, 0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
		cache.Put(context, tensor, tensor)

		// shift window by adding position 7
		err = cache.StartForward(context, input.Batch{
			Positions: []int32{7},
			Sequences: []int{0},
		}, false)
		if err != nil {
			t.Fatalf("StartForward failed: %v", err)
		}

		cache.SetLayer(0)
		tensor = context.FromFloats([]float32{8}, 1, 1, 1)
		cache.Put(context, tensor, tensor)

		// only the latest position has overlapping windows
		if cache.CanResume(0, 0) {
			t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
		}
		if cache.CanResume(0, 1) {
			t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
		}
		if cache.CanResume(0, 2) {
			t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
		}
		if cache.CanResume(0, 3) {
			t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
		}
		if cache.CanResume(0, 4) {
			t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
		}
		if cache.CanResume(0, 5) {
			t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
		}
		if !cache.CanResume(0, 6) {
			t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
		}
		if !cache.CanResume(0, 7) {
			t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
		}
	})
}

type testBackend struct {
	ml.Backend

	permutedV bool
}

func (b *testBackend) NewContext() ml.Context {
@@ -612,6 +633,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
	return &testContext{}
}

func (b *testBackend) CacheConfig() ml.CacheConfig {
	return ml.CacheConfig{PermutedV: b.permutedV}
}

type testContext struct {
	ml.Context
}

@@ -766,6 +791,102 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	return view
}

func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(t.shape) > 4 || len(order) > 4 {
		panic("permute only supports up to 4 dimensions")
	}
	if len(order) != len(t.shape) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute expects 4 axes, so fill in any missing dimensions.
	orderFull := append(make([]int, 0, 4), order...)
	for len(orderFull) < 4 {
		orderFull = append(orderFull, len(orderFull))
	}

	seen := [4]bool{}
	shape4 := [4]int{1, 1, 1, 1}
	for i := 0; i < len(t.shape) && i < 4; i++ {
		shape4[i] = t.shape[i]
	}

	newShape4 := [4]int{1, 1, 1, 1}
	for axis := range 4 {
		dst := orderFull[axis]
		if dst < 0 || dst >= 4 {
			panic("invalid axis for permute")
		}
		if seen[dst] {
			panic("duplicate axis for permute")
		}
		seen[dst] = true
		newShape4[dst] = shape4[axis]
	}

	total := len(t.data)
	newData := make([]float32, total)
	if total > 0 {
		oldDims := shape4
		newDims := newShape4

		oldStride := [4]int{1, 1, 1, 1}
		newStride := [4]int{1, 1, 1, 1}
		for i := 1; i < 4; i++ {
			oldStride[i] = oldStride[i-1] * oldDims[i-1]
			newStride[i] = newStride[i-1] * newDims[i-1]
		}

		var coords [4]int
		var newCoords [4]int
		for idx := range total {
			remainder := idx
			for axis := range 4 {
				dim := oldDims[axis]
				if dim == 0 {
					coords[axis] = 0
					continue
				}
				coords[axis] = remainder % dim
				remainder /= dim
			}

			for axis := range 4 {
				newCoords[orderFull[axis]] = coords[axis]
			}

			newIndex := 0
			for axis := range 4 {
				if newDims[axis] == 0 {
					continue
				}
				newIndex += newCoords[axis] * newStride[axis]
			}

			newData[newIndex] = t.data[idx]
		}
	}

	numDims := 4
	for numDims > 1 && newShape4[numDims-1] <= 1 {
		numDims--
	}

	newShape := make([]int, numDims)
	copy(newShape, newShape4[:numDims])

	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        newData,
		shape:       newShape,
	}
}

func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
	dst := t
	srcTensor := src.(*testTensor)
...