Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
e9527893
Commit
e9527893
authored
Mar 10, 2025
by
Michael Yang
Browse files
use non-causal mask only for image positions
parent
9d2a20a7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
10 deletions
+18
-10
kvcache/causal.go
kvcache/causal.go
+12
-8
model/models/gemma3/model_text.go
model/models/gemma3/model_text.go
+6
-2
No files found.
kvcache/causal.go
View file @
e9527893
...
@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
...
@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type
Causal
struct
{
type
Causal
struct
{
DType
ml
.
DType
DType
ml
.
DType
Capacity
int32
Capacity
int32
causal
bool
windowSize
int32
windowSize
int32
opts
CausalOptions
// config controls mostly backend-specific optimizations
// config controls mostly backend-specific optimizations
config
*
ml
.
CacheConfig
config
*
ml
.
CacheConfig
...
@@ -79,7 +80,6 @@ type cellRange struct {
...
@@ -79,7 +80,6 @@ type cellRange struct {
func
NewCausalCache
(
shift
shiftFn
)
*
Causal
{
func
NewCausalCache
(
shift
shiftFn
)
*
Causal
{
return
&
Causal
{
return
&
Causal
{
causal
:
true
,
windowSize
:
math
.
MaxInt32
,
windowSize
:
math
.
MaxInt32
,
shiftFn
:
shift
,
shiftFn
:
shift
,
ctxs
:
make
(
map
[
int
]
ml
.
Context
),
ctxs
:
make
(
map
[
int
]
ml
.
Context
),
...
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
...
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
func
NewSWACache
(
windowSize
int32
,
shift
shiftFn
)
*
Causal
{
func
NewSWACache
(
windowSize
int32
,
shift
shiftFn
)
*
Causal
{
return
&
Causal
{
return
&
Causal
{
causal
:
true
,
windowSize
:
windowSize
,
windowSize
:
windowSize
,
shiftFn
:
shift
,
shiftFn
:
shift
,
ctxs
:
make
(
map
[
int
]
ml
.
Context
),
ctxs
:
make
(
map
[
int
]
ml
.
Context
),
...
@@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
...
@@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask
:=
make
([]
float32
,
batchSize
*
length
)
mask
:=
make
([]
float32
,
batchSize
*
length
)
for
i
:=
range
c
.
curBatchSize
{
for
i
:=
range
c
.
curBatchSize
{
enabled
:=
!
slices
.
Contains
(
c
.
opts
.
Except
,
c
.
curPositions
[
i
])
for
j
:=
c
.
curCellRange
.
min
;
j
<=
c
.
curCellRange
.
max
;
j
++
{
for
j
:=
c
.
curCellRange
.
min
;
j
<=
c
.
curCellRange
.
max
;
j
++
{
if
!
slices
.
Contains
(
c
.
cells
[
j
]
.
sequences
,
c
.
curSequences
[
i
])
||
if
!
slices
.
Contains
(
c
.
cells
[
j
]
.
sequences
,
c
.
curSequences
[
i
])
||
(
c
.
causal
&&
c
.
cells
[
j
]
.
pos
>
c
.
curPositions
[
i
])
||
(
enabled
&&
c
.
cells
[
j
]
.
pos
>
c
.
curPositions
[
i
])
||
c
.
cells
[
j
]
.
pos
<
c
.
curPositions
[
i
]
-
c
.
windowSize
{
c
.
cells
[
j
]
.
pos
<
c
.
curPositions
[
i
]
-
c
.
windowSize
{
mask
[
i
*
length
+
(
j
-
c
.
curCellRange
.
min
)]
=
float32
(
math
.
Inf
(
-
1
))
mask
[
i
*
length
+
(
j
-
c
.
curCellRange
.
min
)]
=
float32
(
math
.
Inf
(
-
1
))
}
}
...
@@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
...
@@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
c
.
curLayer
=
layer
c
.
curLayer
=
layer
}
}
type
CausalOptions
struct
{
// Enabled controls whether the causal mask is generated for a particular position.
Except
[]
int32
}
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// This state carries over to future forward passes. The default value is true.
// This state carries over to future forward passes. The default value is true.
//
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
// example, when initializing the cache.
func
(
c
*
Causal
)
SetCausal
(
ctx
ml
.
Context
,
causal
bool
)
{
func
(
c
*
Causal
)
SetCausal
(
ctx
ml
.
Context
,
opts
CausalOptions
)
{
if
c
.
causal
!=
causal
{
if
!
slices
.
Equal
(
c
.
opts
.
Except
,
opts
.
Except
)
{
c
.
causal
=
causal
c
.
opts
=
opts
if
ctx
!=
nil
{
if
ctx
!=
nil
{
var
err
error
var
err
error
c
.
curMask
,
err
=
c
.
buildMask
(
ctx
)
c
.
curMask
,
err
=
c
.
buildMask
(
ctx
)
...
...
model/models/gemma3/model_text.go
View file @
e9527893
...
@@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
...
@@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenState
=
hiddenState
.
Set
(
ctx
,
visionOutputs
,
offset
*
hiddenState
.
Stride
(
1
))
hiddenState
=
hiddenState
.
Set
(
ctx
,
visionOutputs
,
offset
*
hiddenState
.
Stride
(
1
))
if
causal
,
ok
:=
cache
.
(
*
kvcache
.
WrapperCache
)
.
UnderlyingCache
()
.
(
*
kvcache
.
Causal
);
ok
{
if
causal
,
ok
:=
cache
.
(
*
kvcache
.
WrapperCache
)
.
UnderlyingCache
()
.
(
*
kvcache
.
Causal
);
ok
{
causal
.
SetCausal
(
ctx
,
false
)
except
:=
make
([]
int32
,
visionOutputs
.
Dim
(
1
))
defer
causal
.
SetCausal
(
ctx
,
true
)
for
i
:=
0
;
i
<
visionOutputs
.
Dim
(
1
);
i
++
{
except
[
i
]
=
int32
(
offset
+
i
)
}
causal
.
SetCausal
(
ctx
,
kvcache
.
CausalOptions
{
Except
:
except
})
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment