OpenDAS / ollama / Commits / 4346c240

Commit 4346c240, authored Mar 07, 2025 by Jesse Gross, committed by Michael Yang on Mar 11, 2025.

    fix drift from main

Parent: 4b037a97

Showing 5 changed files with 49 additions and 22 deletions (+49 -22).
Changed files:
    kvcache/causal_test.go               +4  -0
    model/models/gemma2/model.go         +25 -11
    model/models/gemma3/model_text.go    +17 -8
    model/models/gemma3/model_vision.go  +1  -1
    model/process_text_spm_test.go       +2  -2
kvcache/causal_test.go

@@ -441,6 +441,10 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
     panic("not implemented")
 }
 
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+    panic("not implemented")
+}
+
 func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
     panic("not implemented")
 }
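Context for the hunk above: testTensor appears to be a test double that stubs every method of the ml.Tensor interface, so when main grows the interface (here with AvgPool1D) this file must grow a matching stub or the package stops compiling. A minimal, self-contained sketch of that pattern with hypothetical names:

    package main

    // Tensor stands in for an interface that keeps gaining methods.
    type Tensor interface {
        Scale(s float64) Tensor
        AvgPool1D(k, s, p int) Tensor // method newly added upstream
    }

    // testTensor is a test double; every interface method needs at least a stub,
    // otherwise testTensor no longer satisfies Tensor and the tests fail to build.
    type testTensor struct{}

    func (t *testTensor) Scale(s float64) Tensor       { panic("not implemented") }
    func (t *testTensor) AvgPool1D(k, s, p int) Tensor { panic("not implemented") }

    func main() {
        var _ Tensor = (*testTensor)(nil) // compile-time completeness check
    }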
model/models/gemma2/model.go

@@ -64,6 +64,7 @@ func New(c ml.Config) (model.Model, error) {
     slidingWindowLen := int32(c.Uint("attention.sliding_window"))
     m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+    m.Cache.SetConfig(ml.CacheConfig{})
 
     return &m, nil
 }
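The wrapped cache built above reflects Gemma 2's attention layout: some layers attend within a sliding window and others attend globally, so one cache of each kind is created and one of them is selected per layer later in Forward (the SetLayerType call further down). A toy, self-contained sketch of that routing idea, not the kvcache package itself, with the alternating layer layout as an assumption:

    package main

    import "fmt"

    // wrapperCache is a stand-in for kvcache.WrapperCache: it owns one cache per
    // attention flavour and routes subsequent calls to whichever one is selected.
    type wrapperCache struct {
        kinds   []string // stand-ins for NewSWACache(...) and NewCausalCache(...)
        current int
    }

    func (w *wrapperCache) SetLayerType(t int) { w.current = t }
    func (w *wrapperCache) Put(k, v string) {
        fmt.Printf("store %s/%s in the %s cache\n", k, v, w.kinds[w.current])
    }

    func main() {
        wc := &wrapperCache{kinds: []string{"sliding-window", "causal"}}
        for i := 0; i < 4; i++ {
            wc.SetLayerType(i % 2) // assumption: local and global layers alternate
            wc.Put("k", "v")
        }
    }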
@@ -84,7 +85,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
     if opts.largeModelScaling {
-        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
+        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
     } else {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
     }

(The removed and added lines render identically here; the difference is whitespace-only in this view.)
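The branch above picks between two query-scaling factors: the default divides by the square root of the per-head key length (attnKeyLen), while largeModelScaling divides by the square root of hiddenSize/numHeads. A small self-contained example with made-up dimensions (not taken from any model config here) showing that the two factors differ whenever the head dimension is not exactly hiddenSize/numHeads:

    package main

    import (
        "fmt"
        "math"
    )

    func main() {
        // Hypothetical dimensions, chosen only to illustrate the two branches.
        hiddenSize, numHeads, attnKeyLen := 4608, 32, 128

        largeScale := 1.0 / math.Sqrt(float64(hiddenSize/numHeads)) // 1/sqrt(144) ≈ 0.0833
        defaultScale := 1.0 / math.Sqrt(float64(attnKeyLen))        // 1/sqrt(128) ≈ 0.0884

        fmt.Printf("largeModelScaling: %.4f  default: %.4f\n", largeScale, defaultScale)
    }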
@@ -99,8 +100,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     cache.Put(ctx, k, v)
     k, v, mask := cache.Get(ctx)
 
-    q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-    k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+    q = q.Permute(ctx, 0, 2, 1, 3)
+    k = k.Permute(ctx, 0, 2, 1, 3)
     v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
     kq := k.Mulmat(ctx, q)
@@ -144,12 +145,20 @@ type Layer struct {
     PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     residual := hiddenState
 
     hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
     hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
     hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+    // In the final layer (outputs != nil), optimize by pruning to just the token positions
+    // we need logits for.
+    if outputs != nil {
+        hiddenState = hiddenState.Rows(ctx, outputs)
+        residual = residual.Rows(ctx, outputs)
+    }
+
     hiddenState = hiddenState.Add(ctx, residual)
     residual = hiddenState
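The new outputs parameter is non-nil only for the final layer. When it is set, hiddenState and residual are cut down, before the residual add, to just the rows (token positions) whose logits the caller requested, so everything after this layer runs on far fewer rows; during token-by-token decoding that is typically a single row. A minimal sketch of the row-pruning idea using plain slices rather than the ml.Tensor API:

    // pruneRows keeps only the requested rows of a [seq][hidden] matrix, mirroring
    // what hiddenState.Rows(ctx, outputs) does to the tensor in the hunk above.
    func pruneRows(hidden [][]float32, outputs []int32) [][]float32 {
        pruned := make([][]float32, 0, len(outputs))
        for _, row := range outputs {
            pruned = append(pruned, hidden[row])
        }
        return pruned
    }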
@@ -170,6 +179,11 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
         return nil, err
     }
 
+    outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+    if err != nil {
+        return nil, err
+    }
+
     hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
     hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

@@ -182,7 +196,13 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
         m.Cache.SetLayer(i)
         wc := m.Cache.(*kvcache.WrapperCache)
         wc.SetLayerType(cacheType)
-        hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+
+        var lastLayerOutputs ml.Tensor
+        if i == len(m.Layers)-1 {
+            lastLayerOutputs = outputs
+        }
+
+        hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
     }
 
     hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

@@ -192,12 +212,6 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
     hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
     hiddenState = hiddenState.Tanh(ctx)
     hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
 
-    outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-    if err != nil {
-        return nil, err
-    }
-
     return hiddenState.Rows(ctx, outputs), nil
 }
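The Scale(1/softcap), Tanh, Scale(softcap) sequence kept above is Gemma's final logit soft-capping: each logit x becomes finalLogitSoftcap * tanh(x / finalLogitSoftcap), which squashes logits smoothly into (-softcap, +softcap). The same computation for a single value, with a made-up cap:

    package main

    import (
        "fmt"
        "math"
    )

    // softcapLogit is the scalar form of the Scale -> Tanh -> Scale sequence above.
    func softcapLogit(x, softcap float64) float64 {
        return softcap * math.Tanh(x/softcap)
    }

    func main() {
        // softcap = 30 is an illustrative value; large logits saturate near ±30.
        fmt.Println(softcapLogit(5, 30), softcapLogit(500, 30)) // ≈ 4.95 and ≈ 30
    }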
model/models/gemma3/model_text.go

@@ -66,9 +66,6 @@ func newTextModel(c ml.Config) *TextModel {
         },
     }
 
-    slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
-    m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-
     return &m
 }
@@ -145,12 +142,20 @@ type TextLayer struct {
     PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
     residual := hiddenState
 
     hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
     hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
     hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+    // In the final layer (outputs != nil), optimize by pruning to just the token positions
+    // we need logits for.
+    if outputs != nil {
+        hiddenState = hiddenState.Rows(ctx, outputs)
+        residual = residual.Rows(ctx, outputs)
+    }
+
     hiddenState = hiddenState.Add(ctx, residual)
     residual = hiddenState
@@ -181,7 +186,13 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
         cache.SetLayer(i)
         wc := cache.(*kvcache.WrapperCache)
         wc.SetLayerType(cacheType)
-        hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
+
+        var lastLayerOutputs ml.Tensor
+        if i == len(m.Layers)-1 {
+            lastLayerOutputs = outputs
+        }
+
+        hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
     }
 
     hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -190,7 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
     // final logit softcap
     hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
     hiddenState = hiddenState.Tanh(ctx)
-    hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
-
-    return hiddenState.Rows(ctx, outputs)
+    return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
 }
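With this hunk, TextModel.Forward no longer ends in hiddenState.Rows(ctx, outputs): the last TextLayer already pruned to the requested rows (see the TextLayer.Forward hunk above). Moving the selection earlier is safe because the steps that follow it, the output RMSNorm, the output projection (not shown in this hunk), and the softcap, each act on every row independently, so selecting rows before or after them yields the same values for the selected rows. A small self-contained check of that commutation with an arbitrary per-row function standing in for those steps:

    package main

    import "fmt"

    // f stands in for any per-row transformation (norm, projection, softcap, ...).
    func f(row []float64) []float64 {
        out := make([]float64, len(row))
        for i, v := range row {
            out[i] = 2*v + 1
        }
        return out
    }

    func applyRows(x [][]float64) [][]float64 {
        out := make([][]float64, len(x))
        for i, row := range x {
            out[i] = f(row)
        }
        return out
    }

    func selectRows(x [][]float64, idx []int) [][]float64 {
        out := make([][]float64, len(idx))
        for i, j := range idx {
            out[i] = x[j]
        }
        return out
    }

    func main() {
        x := [][]float64{{1, 2}, {3, 4}, {5, 6}}
        idx := []int{2}
        fmt.Println(selectRows(applyRows(x), idx)) // [[11 13]]
        fmt.Println(applyRows(selectRows(x, idx))) // [[11 13]]
    }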
model/models/gemma3/model_vision.go (+1 -1; hunk not shown in this view)
model/process_text_spm_test.go

@@ -73,7 +73,7 @@ func TestSentencePieceEncode(t *testing.T) {
     }
 
     for _, want := range cases {
-        ids, err := tokenizer.Encode(want)
+        ids, err := tokenizer.Encode(want, true)
         if err != nil {
             t.Fatal(err)
         }
@@ -98,7 +98,7 @@ func TestSentencePieceEncode(t *testing.T) {
     }
 
     for _, want := range cases {
-        ids, err := tokenizer.Encode(want.token)
+        ids, err := tokenizer.Encode(want.token, true)
         if err != nil {
             t.Fatal(err)
         }
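Both call sites in TestSentencePieceEncode gain a second argument, so Encode's signature on main evidently takes an extra boolean, and the tests pass true. What the flag controls is not visible in this diff; in tokenizers such a flag commonly toggles adding special tokens (for example BOS). A toy stand-in for an API of that shape, with names and behaviour that are assumptions rather than this repository's code:

    package main

    import "fmt"

    // encode mimics a tokenizer.Encode(text, addSpecial) call: emit one placeholder
    // id per rune and, if asked, prepend an assumed BOS id. Purely illustrative.
    func encode(s string, addSpecial bool) ([]int32, error) {
        ids := make([]int32, 0, len(s)+1)
        if addSpecial {
            ids = append(ids, 2) // assumed BOS id
        }
        for range s {
            ids = append(ids, 42) // placeholder piece id
        }
        return ids, nil
    }

    func main() {
        ids, err := encode("hi", true)
        if err != nil {
            panic(err)
        }
        fmt.Println(ids) // [2 42 42]
    }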