OpenDAS / ollama / Commits

Commit 7ba9fa9c
fixes for maverick
Authored Apr 21, 2025 by Michael Yang; committed by Michael Yang on Apr 25, 2025
Parent: 8bf11b84
Showing 3 changed files with 53 additions and 26 deletions (+53, -26):

    convert/convert_llama4.go          +2   -2
    model/models/llama4/model.go       +3   -2
    model/models/llama4/model_text.go  +48  -22
convert/convert_llama4.go (view file @ 7ba9fa9c)

@@ -45,8 +45,8 @@ func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
 		}
 	}
-	kv["llama4.intermediate_size"] = p.TextModel.IntermediateSizeMLP
-	kv["llama4.intermediate_size_moe"] = p.TextModel.IntermediateSize
+	kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
+	kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
 	kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
 	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
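The rename swaps the earlier intermediate_size / intermediate_size_moe keys for feed_forward_length-style names. A minimal read-side sketch (hypothetical variable names), assuming fs.Config resolves the bare key under the architecture prefix, the same way expert_count and expert_used_count are written as "llama4.expert_count" / "llama4.expert_used_count" here and read without the prefix in newTextModel:

	// Hypothetical read side of the renamed keys (sketch only):
	intermediateSizeMLP := int(c.Uint("feed_forward_length"))        // written as "llama4.feed_forward_length"
	intermediateSizeMoE := int(c.Uint("expert_feed_forward_length")) // written as "llama4.expert_feed_forward_length"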
model/models/llama4/model.go (view file @ 7ba9fa9c)

@@ -35,7 +35,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),

@@ -52,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
 	}
 	m.Cache = kvcache.NewWrapperCache(
-		kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size")), m.Shift),
+		kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size", 8192)), m.Shift),
 		kvcache.NewCausalCache(m.Shift),
 	)
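For orientation, not part of the diff: NewWrapperCache composes its arguments in order, and the layer loop in model_text.go below switches between them with wc.SetLayerType(0) for the chunked-attention cache and wc.SetLayerType(1) for the causal cache. A minimal sketch of what the new default argument buys, assuming a converted model whose GGUF carries no attention.chunk_size entry:

	// Without the fallback a missing key presumably yields the zero value;
	// with it, such files get an 8192-token chunk window instead.
	chunkSize := int32(c.Uint("attention.chunk_size", 8192)) // 8192 when the key is absent
	chunked := kvcache.NewChunkedAttentionCache(chunkSize, m.Shift)
	_ = chunked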
model/models/llama4/model_text.go (view file @ 7ba9fa9c)
@@ -19,7 +19,7 @@ type TextAttention struct {
 	RopeFactors ml.Tensor `gguf:"rope_factors"`
 }
-func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
+func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
 	batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
 	query := sa.Query.Forward(ctx, hiddenStates)
@@ -33,11 +33,15 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
 	if useRope {
 		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
 		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
 	}
 	if opts.useQKNorm {
 		query = query.RMSNorm(ctx, nil, opts.eps)
 		key = key.RMSNorm(ctx, nil, opts.eps)
 	}
+	if attentionScales != nil && !useRope {
+		query = query.Mul(ctx, attentionScales)
+	}
+
 	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
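A note on the guard, not part of the diff: the scaling touches only the query, and only on layers that skip RoPE; when temperature tuning is disabled the tensor stays nil and no layer is scaled. Restated with the cases spelled out (the comments are annotation, the code matches the hunk above):

	// useRope == true  (chunked / RoPE layers)   -> no scaling
	// useRope == false (global / NoPE layers)    -> query scaled per position
	// attentionScales == nil (tuning disabled)   -> no scaling anywhere
	if attentionScales != nil && !useRope {
		query = query.Mul(ctx, attentionScales)
	}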
@@ -82,7 +86,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	return nextStates
 }
-// TextSharedExpert is TextMLP with different names
+// TextSharedExpert is TextMLP with different tensor names
 type TextSharedExpert struct {
 	Gate *nn.Linear `gguf:"ffn_gate_shexp"`
 	Up   *nn.Linear `gguf:"ffn_up_shexp"`
@@ -122,12 +126,12 @@ type TextLayer struct {
 	FeedForward TextFeedForward
 }
-func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
+func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
 	residual := hiddenStates
 	// self attention
 	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts)
+	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, useRope, opts)
 	if outputs != nil {
 		hiddenStates = hiddenStates.Rows(ctx, outputs)
@@ -151,7 +155,11 @@ type TextOptions struct {
 	ropeBase, ropeScale float32
 	eps                 float32
 	interleaveLayerStep int
+	noRopeInterval      int
 	useQKNorm           bool
+	attentionTemperatureTuning bool
+	attentionScale             float64
+	attentionFloorScale        float64
 }

 type TextModel struct {
@@ -178,18 +186,22 @@ func newTextModel(c fs.Config) *TextModel {
 	return &TextModel{
 		Layers: layers,
 		TextOptions: &TextOptions{
 			hiddenSize:          int(c.Uint("embedding_length")),
 			numHeads:            int(c.Uint("attention.head_count")),
 			numKVHeads:          int(c.Uint("attention.head_count_kv")),
 			headDim:             int(c.Uint("attention.head_dim", 128)),
 			numExperts:          int(c.Uint("expert_count")),
 			numExpertsUsed:      int(c.Uint("expert_used_count")),
 			ropeDim:             int(c.Uint("rope.dimension_count")),
 			ropeBase:            c.Float("rope.freq_base"),
 			ropeScale:           c.Float("rope.freq_scale", 1),
 			eps:                 c.Float("attention.layer_norm_rms_epsilon"),
 			interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
+			noRopeInterval:      int(c.Uint("no_rope_interval", 4)),
 			useQKNorm:           c.Bool("use_qk_norm", true),
+			attentionTemperatureTuning: c.Bool("attention.temperature_tuning", true),
+			attentionScale:             float64(c.Float("attention.scale", 0.1)),
+			attentionFloorScale:        float64(c.Float("attention.floor_scale", 8192)),
 		},
 	}
 }
@@ -207,11 +219,25 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
 	}
+	var attentionScales ml.Tensor
+	if m.attentionTemperatureTuning {
+		scales := make([]float32, len(batch.Positions))
+		for i, p := range batch.Positions {
+			scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
+		}
+
+		var err error
+		attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
+		if err != nil {
+			panic(err)
+		}
+	}
+
 	for i, layer := range m.Layers {
 		cache.SetLayer(i)
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(1)
-		useChunkedAttention := (i+1)%4 != 0
+		useChunkedAttention := (i+1)%m.noRopeInterval != 0
 		if useChunkedAttention {
 			wc.SetLayerType(0)
 		}
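To make the temperature tuning concrete: with the defaults set in newTextModel (attentionFloorScale = 8192, attentionScale = 0.1, noRopeInterval = 4), layers where (i+1)%4 == 0 take the causal/NoPE branch, and those are the layers whose queries pick up these scales. The standalone sketch below reproduces the same formula with those defaults; the chosen positions and the main() harness are illustrative only.

	package main

	import (
		"fmt"
		"math"
	)

	func main() {
		const floorScale, attnScale = 8192.0, 0.1
		for _, p := range []int{0, 8191, 8192, 16383, 100000} {
			// Same expression as TextModel.Forward: log(floor((p+1)/floorScale + 1)) * scale + 1
			s := math.Log(math.Floor(((float64(p)+1.0)/floorScale)+1.0))*attnScale + 1.0
			fmt.Printf("position %6d -> query scale %.4f\n", p, s)
		}
		// position      0 -> query scale 1.0000  (positions below the floor stay unscaled)
		// position   8191 -> query scale 1.0693
		// position   8192 -> query scale 1.0693
		// position  16383 -> query scale 1.1099
		// position 100000 -> query scale 1.2565
	}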
@@ -221,7 +247,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}
-		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
 	}
 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)