Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
8bf11b84
"model/vscode:/vscode.git/clone" did not exist on "7ba9fa9c7d0bc73abacca88d6827d973d7ba92cf"
Commit
8bf11b84
authored
Apr 10, 2025
by
Michael Yang
Committed by
Michael Yang
Apr 25, 2025
Browse files
chunked attention
parent
470af8ab
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
4 deletions
+84
-4
convert/convert_llama4.go
convert/convert_llama4.go
+2
-0
kvcache/causal.go
kvcache/causal.go
+13
-0
kvcache/causal_test.go
kvcache/causal_test.go
+68
-2
model/models/llama4/model.go
model/models/llama4/model.go
+1
-2
No files found.
convert/convert_llama4.go
View file @
8bf11b84
...
...
@@ -19,6 +19,7 @@ type llama4Model struct {
InterleaveMOELayerStep
uint32
`json:"interleave_moe_layer_step"`
UseQKNorm
bool
`json:"use_qk_norm"`
IntermediateSizeMLP
uint32
`json:"intermediate_size_mlp"`
AttentionChunkSize
uint32
`json:"attention_chunk_size"`
}
`json:"text_config"`
VisionModel
struct
{
NumHiddenLayers
uint32
`json:"num_hidden_layers"`
...
...
@@ -51,6 +52,7 @@ func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv
[
"llama4.expert_used_count"
]
=
p
.
TextModel
.
NumExpertsPerToken
kv
[
"llama4.interleave_moe_layer_step"
]
=
p
.
TextModel
.
InterleaveMOELayerStep
kv
[
"llama4.use_qk_norm"
]
=
p
.
TextModel
.
UseQKNorm
kv
[
"llama4.attention.chunk_size"
]
=
p
.
TextModel
.
AttentionChunkSize
kv
[
"llama4.vision.block_count"
]
=
p
.
VisionModel
.
NumHiddenLayers
kv
[
"llama4.vision.embedding_length"
]
=
p
.
VisionModel
.
HiddenSize
...
...
kvcache/causal.go
View file @
8bf11b84
...
...
@@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type
Causal
struct
{
DType
ml
.
DType
windowSize
int32
chunkSize
int32
opts
CausalOptions
...
...
@@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func
NewChunkedAttentionCache
(
chunkSize
int32
,
shift
shiftFn
)
*
Causal
{
return
&
Causal
{
windowSize
:
math
.
MaxInt32
,
chunkSize
:
chunkSize
,
shiftFn
:
shift
,
ctxs
:
make
(
map
[
int
]
ml
.
Context
),
keys
:
make
(
map
[
int
]
ml
.
Tensor
),
values
:
make
(
map
[
int
]
ml
.
Tensor
),
}
}
func
(
c
*
Causal
)
Init
(
backend
ml
.
Backend
,
dtype
ml
.
DType
,
maxSequences
,
capacity
,
maxBatch
int
)
{
if
c
.
config
==
nil
{
var
config
ml
.
CacheConfig
...
...
@@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
for
j
:=
c
.
curCellRange
.
min
;
j
<=
c
.
curCellRange
.
max
;
j
++
{
if
!
slices
.
Contains
(
c
.
cells
[
j
]
.
sequences
,
c
.
curSequences
[
i
])
||
(
enabled
&&
c
.
cells
[
j
]
.
pos
>
c
.
curPositions
[
i
])
||
c
.
chunkSize
>
0
&&
c
.
cells
[
j
]
.
pos
<
c
.
curPositions
[
i
]
-
c
.
curPositions
[
i
]
%
c
.
chunkSize
||
c
.
cells
[
j
]
.
pos
<
c
.
curPositions
[
i
]
-
c
.
windowSize
{
mask
[
i
*
length
+
(
j
-
c
.
curCellRange
.
min
)]
=
float32
(
math
.
Inf
(
-
1
))
}
...
...
kvcache/causal_test.go
View file @
8bf11b84
...
...
@@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
testCache
(
t
,
backend
,
cache
,
tests
)
}
// TestChunkedAttention verifies the attention mask produced by a
// chunked-attention cache with chunkSize=2: a token at position p may
// attend only to positions in [p - p%2, p], i.e. within its own
// 2-position chunk and never across a chunk boundary.
func TestChunkedAttention(t *testing.T) {
	cache := NewChunkedAttentionCache(2, nil)
	defer cache.Close()

	var b testBackend
	cache.Init(&b, ml.DTypeF16, 1, 16, 16)

	// x marks a masked-out (disallowed) position: -Inf before softmax.
	x := float32(math.Inf(-1))

	testCache(
		t, &b, cache,
		[]testCase{
			{
				name:          "FirstBatch",
				in:            []float32{1, 2, 3, 4},
				inShape:       []int{1, 1, 4},
				seqs:          []int{0, 0, 0, 0},
				pos:           []int32{0, 1, 2, 3},
				expected:      []float32{1, 2, 3, 4},
				expectedShape: []int{1, 1, 4},
				// 4x4 mask, one row per query position 0..3.
				// Chunks are {0,1} and {2,3}: pos 2 cannot see pos 0/1.
				expectedMask: []float32{
					0, x, x, x,
					0, 0, x, x,
					x, x, 0, x,
					x, x, 0, 0,
				},
			},
			{
				name:          "SecondBatch",
				in:            []float32{5, 6, 7},
				inShape:       []int{1, 1, 3},
				seqs:          []int{0, 0, 0},
				pos:           []int32{4, 5, 6},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7},
				expectedShape: []int{1, 1, 7},
				// 3x7 mask over cache cells 0..6. Chunks {4,5} and {6}:
				// pos 4 sees only itself, pos 5 sees {4,5}, pos 6 starts
				// a new chunk and sees only itself.
				expectedMask: []float32{
					x, x, x, x, 0, x, x,
					x, x, x, x, 0, 0, x,
					x, x, x, x, x, x, 0,
				},
			},
			{
				name:          "ThirdBatch",
				in:            []float32{8, 9},
				inShape:       []int{1, 1, 2},
				seqs:          []int{0, 0},
				pos:           []int32{7, 8},
				expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
				expectedShape: []int{1, 1, 9},
				// 2x9 mask. Pos 7 completes chunk {6,7}; pos 8 opens
				// chunk {8} and sees only itself.
				expectedMask: []float32{
					x, x, x, x, x, x, 0, 0, x,
					x, x, x, x, x, x, x, x, 0,
				},
			},
		},
	)
}
func
TestSequences
(
t
*
testing
.
T
)
{
backend
:=
&
testBackend
{}
cache
:=
NewCausalCache
(
nil
)
...
...
@@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context
.
Forward
(
out
,
mask
)
.
Compute
(
out
,
mask
)
if
!
slices
.
Equal
(
out
.
Floats
(),
test
.
expected
)
||
!
slices
.
Equal
(
out
.
Shape
(),
test
.
expectedShape
)
||
!
slices
.
Equal
(
mask
.
Floats
(),
test
.
expectedMask
)
{
t
.
Errorf
(
"TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v"
,
out
.
Floats
(),
out
.
Shape
(),
test
.
expected
,
test
.
expectedShape
,
mask
.
Floats
(),
mask
.
Shape
(),
test
.
expectedMask
)
if
!
slices
.
Equal
(
out
.
Floats
(),
test
.
expected
)
{
t
.
Errorf
(
"TestCache: have %v; want %v"
,
out
.
Floats
(),
test
.
expected
)
}
if
!
slices
.
Equal
(
out
.
Shape
(),
test
.
expectedShape
)
{
t
.
Errorf
(
"TestCache: has shape %v; want %v"
,
out
.
Shape
(),
test
.
expectedShape
)
}
if
!
slices
.
Equal
(
mask
.
Floats
(),
test
.
expectedMask
)
{
t
.
Errorf
(
"TestCache: have mask: have %v want %v"
,
mask
.
Floats
(),
test
.
expectedMask
)
}
})
}
...
...
model/models/llama4/model.go
View file @
8bf11b84
...
...
@@ -52,8 +52,7 @@ func New(c fs.Config) (model.Model, error) {
}
m
.
Cache
=
kvcache
.
NewWrapperCache
(
// TODO: pretend this is chunked attention for now
kvcache
.
NewSWACache
(
8192
,
m
.
Shift
),
kvcache
.
NewChunkedAttentionCache
(
int32
(
c
.
Uint
(
"attention.chunk_size"
)),
m
.
Shift
),
kvcache
.
NewCausalCache
(
m
.
Shift
),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment