OpenDAS / ollama / Commits

Commit 8bf11b84
chunked attention

Authored Apr 10, 2025 by Michael Yang; committed by Michael Yang on Apr 25, 2025.
Parent: 470af8ab

Showing 4 changed files with 84 additions and 4 deletions:

convert/convert_llama4.go      +2   -0
kvcache/causal.go              +13  -0
kvcache/causal_test.go         +68  -2
model/models/llama4/model.go   +1   -2
convert/convert_llama4.go

@@ -19,6 +19,7 @@ type llama4Model struct {
 		InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
 		UseQKNorm              bool   `json:"use_qk_norm"`
 		IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
+		AttentionChunkSize     uint32 `json:"attention_chunk_size"`
 	} `json:"text_config"`
 	VisionModel struct {
 		NumHiddenLayers uint32 `json:"num_hidden_layers"`
@@ -51,6 +52,7 @@ func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
 	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
 	kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
 	kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
+	kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
 	kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
 	kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
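For orientation, here is a minimal, self-contained Go sketch of how the field added above travels from an upstream config.json into the converted metadata. The types are reduced stand-ins, not the converter's real llama4Model, and the 8192 value is only an assumed example:

package main

import (
	"encoding/json"
	"fmt"
)

// Reduced stand-in for llama4Model: only the field this commit adds.
type llama4Config struct {
	TextModel struct {
		AttentionChunkSize uint32 `json:"attention_chunk_size"`
	} `json:"text_config"`
}

func main() {
	// Hypothetical config fragment; the chunk size value is assumed, not taken from the commit.
	raw := []byte(`{"text_config": {"attention_chunk_size": 8192}}`)

	var p llama4Config
	if err := json.Unmarshal(raw, &p); err != nil {
		panic(err)
	}

	// Mirrors kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize above.
	kv := map[string]any{"llama4.attention.chunk_size": p.TextModel.AttentionChunkSize}
	fmt.Println(kv) // map[llama4.attention.chunk_size:8192]
}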
kvcache/causal.go

@@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 type Causal struct {
 	DType      ml.DType
 	windowSize int32
+	chunkSize  int32

 	opts CausalOptions
@@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	}
 }

+func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
+	return &Causal{
+		windowSize: math.MaxInt32,
+		chunkSize:  chunkSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
+}
+
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
@@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
+				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
 				c.cells[j].pos < c.curPositions[i]-c.windowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
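Reading the buildMask change: when chunkSize > 0, a cached key is additionally masked whenever its position falls before the start of the query's chunk, that is before curPositions[i] - curPositions[i]%chunkSize. Below is a minimal standalone sketch of just that predicate, ignoring the sequence membership, the enabled flag, and the sliding-window term that the real code also checks:

package main

import "fmt"

// chunkMasked reports whether a cached key at keyPos is hidden from a query
// at queryPos: the usual causal rule plus the chunk rule added in this
// commit (keys from earlier chunks are masked out).
func chunkMasked(queryPos, keyPos, chunkSize int32) bool {
	if keyPos > queryPos { // causal: never attend to the future
		return true
	}
	if chunkSize > 0 && keyPos < queryPos-queryPos%chunkSize { // earlier chunk
		return true
	}
	return false
}

func main() {
	// With chunkSize = 2, a query at position 3 may only see positions 2 and 3,
	// which matches the last row of the FirstBatch mask in the test below.
	for keyPos := int32(0); keyPos <= 3; keyPos++ {
		fmt.Printf("query=3 key=%d masked=%v\n", keyPos, chunkMasked(3, keyPos, 2))
	}
}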
kvcache/causal_test.go

@@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
 	testCache(t, backend, cache, tests)
 }

+func TestChunkedAttention(t *testing.T) {
+	cache := NewChunkedAttentionCache(2, nil)
+	defer cache.Close()
+
+	var b testBackend
+	cache.Init(&b, ml.DTypeF16, 1, 16, 16)
+
+	x := float32(math.Inf(-1))
+
+	testCache(
+		t, &b, cache,
+		[]testCase{
+			{
+				name:          "FirstBatch",
+				in:            []float32{1, 2, 3, 4},
+				inShape:       []int{1, 1, 4},
+				seqs:          []int{0, 0, 0, 0},
+				pos:           []int32{0, 1, 2, 3},
+				expected:      []float32{1, 2, 3, 4},
+				expectedShape: []int{1, 1, 4},
+				expectedMask: []float32{
+					0, x, x, x,
+					0, 0, x, x,
+					x, x, 0, x,
+					x, x, 0, 0,
+				},
+			},
+			{
+				name:          "SecondBatch",
+				in:            []float32{5, 6, 7},
+				inShape:       []int{1, 1, 3},
+				seqs:          []int{0, 0, 0},
+				pos:           []int32{4, 5, 6},
+				expected:      []float32{1, 2, 3, 4, 5, 6, 7},
+				expectedShape: []int{1, 1, 7},
+				expectedMask: []float32{
+					x, x, x, x, 0, x, x,
+					x, x, x, x, 0, 0, x,
+					x, x, x, x, x, x, 0,
+				},
+			},
+			{
+				name:          "ThirdBatch",
+				in:            []float32{8, 9},
+				inShape:       []int{1, 1, 2},
+				seqs:          []int{0, 0},
+				pos:           []int32{7, 8},
+				expected:      []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
+				expectedShape: []int{1, 1, 9},
+				expectedMask: []float32{
+					x, x, x, x, x, x, 0, 0, x,
+					x, x, x, x, x, x, x, x, 0,
+				},
+			},
+		},
+	)
+}
+
 func TestSequences(t *testing.T) {
 	backend := &testBackend{}
 	cache := NewCausalCache(nil)
@@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context.Forward(out, mask).Compute(out, mask)

-			if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
-				t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
+			if !slices.Equal(out.Floats(), test.expected) {
+				t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
+			}
+
+			if !slices.Equal(out.Shape(), test.expectedShape) {
+				t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
+			}
+
+			if !slices.Equal(mask.Floats(), test.expectedMask) {
+				t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
 			}
 		})
 	}
 }
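As a sanity check on the SecondBatch expectations: at that point the cache holds keys at positions 0 through 6 and the chunk size is 2, so a query at position 5 can only see positions 4 and 5 (5 - 5%2 = 4). The throwaway sketch below, which is not part of the test suite, regenerates the 3x7 expected grid from that rule:

package main

import "fmt"

func main() {
	const chunkSize = 2
	queries := []int32{4, 5, 6}          // SecondBatch positions
	keys := []int32{0, 1, 2, 3, 4, 5, 6} // everything cached so far

	for _, q := range queries {
		for _, k := range keys {
			if k <= q && k >= q-q%chunkSize { // same chunk, not in the future
				fmt.Print("0 ")
			} else {
				fmt.Print("x ")
			}
		}
		fmt.Println()
	}
	// Output:
	// x x x x 0 x x
	// x x x x 0 0 x
	// x x x x x x 0
}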
model/models/llama4/model.go

@@ -52,8 +52,7 @@ func New(c fs.Config) (model.Model, error) {
 	}

 	m.Cache = kvcache.NewWrapperCache(
-		// TODO: pretend this is chunked attention for now
-		kvcache.NewSWACache(8192, m.Shift),
+		kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size")), m.Shift),
 		kvcache.NewCausalCache(m.Shift),
 	)
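Two things worth noting about this hunk. First, the removed lines were a stopgap: a sliding-window cache of 8192 standing in for chunked attention, as the deleted TODO says; the new call instead takes the chunk size from the converted metadata via c.Uint("attention.chunk_size"), which presumably resolves against the llama4. prefix written by the converter. Second, the wrapper still pairs the chunked cache with a plain causal cache, consistent with Llama 4 interleaving chunked local-attention layers with occasional full-attention layers; which wrapped cache a given layer uses is selected elsewhere in the model's forward pass. Purely as an illustration of that pairing, and not ollama's WrapperCache API, a per-layer selector might look like:

package main

import "fmt"

// Hypothetical helper, not taken from the ollama code base: pick which of the
// two wrapped caches a transformer layer should use, assuming every Nth layer
// performs full (global) attention and the rest use chunked attention.
func cacheIndexForLayer(layer, fullAttentionEvery int) int {
	if fullAttentionEvery > 0 && (layer+1)%fullAttentionEvery == 0 {
		return 1 // the kvcache.NewCausalCache entry in the wrapper above
	}
	return 0 // the kvcache.NewChunkedAttentionCache entry in the wrapper above
}

func main() {
	for layer := 0; layer < 8; layer++ {
		fmt.Printf("layer %d -> cache %d\n", layer, cacheIndexForLayer(layer, 4))
	}
}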