OpenDAS / ollama · Commits

Commit f0c66e6d — llama4
Authored Apr 03, 2025 by Michael Yang; committed by Michael Yang on Apr 25, 2025
Parent: 54055a6d

Showing 13 changed files with 833 additions and 15 deletions (+833, -15)
convert/convert.go                      +2    -0
convert/convert_llama.go                +10   -3
convert/convert_llama4.go               +167  -0
convert/reader.go                       +9    -7
convert/reader_safetensors.go           +15   -0
convert/reader_torch.go                 +11   -0
fs/ggml/ggml.go                         +1    -0
ml/backend.go                           +4    -0
ml/backend/ggml/ggml.go                 +34   -5
model/models/llama4/model.go            +100  -0
model/models/llama4/model_text.go       +223  -0
model/models/llama4/model_vision.go     +256  -0
model/models/models.go                  +1    -0
convert/convert.go

@@ -173,6 +173,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	switch p.Architectures[0] {
 	case "LlamaForCausalLM":
 		conv = &llamaModel{}
+	case "Llama4ForConditionalGeneration":
+		conv = &llama4Model{}
 	case "Mistral3ForConditionalGeneration":
 		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
convert/convert_llama.go

@@ -42,6 +42,8 @@ type llamaModel struct {
 	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
 	NormEpsilon      float32 `json:"norm_epsilon"`
 	HeadDim          uint32  `json:"head_dim"`
+
+	skipRepack bool
 }
 
 var _ ModelConverter = (*llamaModel)(nil)

@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 		kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
 	}
 
+	if p.HeadDim > 0 {
+		kv["llama.attention.head_dim"] = p.HeadDim
+	}
+
 	if p.RopeTheta > 0 {
 		kv["llama.rope.freq_base"] = p.RopeTheta
 	}

@@ -133,9 +139,10 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
 	}
 
 	for _, t := range ts {
 		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
 			strings.HasSuffix(t.Name(), "attn_k.weight") {
-			t.SetRepacker(p.repack)
+			if !p.skipRepack {
+				t.SetRepacker(p.repack)
+			}
 		}
 
 		out = append(out, ggml.Tensor{
convert/convert_llama4.go  (new file, mode 100644)

package convert

import (
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type llama4Model struct {
	ModelParameters
	TextModel struct {
		llamaModel

		NumExpertsPerToken     uint32 `json:"num_experts_per_tok"`
		NumLocalExperts        uint32 `json:"num_local_experts"`
		InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
		UseQKNorm              bool   `json:"use_qk_norm"`
		IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
	} `json:"text_config"`
	VisionModel struct {
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ImageSize         uint32  `json:"image_size"`
		PatchSize         uint32  `json:"patch_size"`
		RopeTheta         float32 `json:"rope_theta"`
		NormEpsilon       float32 `json:"norm_eps"`
		PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
	} `json:"vision_config"`
}

// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "llama4"

	for k, v := range p.TextModel.KV(t) {
		if strings.HasPrefix(k, "llama.") {
			kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
		}
	}

	kv["llama4.intermediate_size"] = p.TextModel.IntermediateSizeMLP
	kv["llama4.intermediate_size_moe"] = p.TextModel.IntermediateSize

	kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
	kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
	kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm

	kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
	kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
	kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
	kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
	kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
	kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
	kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
	kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
	kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
	return kv
}

// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
	return append(
		p.TextModel.Replacements(),
		"language_model.", "",
		"vision_model", "v",
		"multi_modal_projector", "mm",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.", "ffn_",
		"shared_expert.down_proj", "down_shexp",
		"shared_expert.gate_proj", "gate_shexp",
		"shared_expert.up_proj", "up_shexp",
		"experts.down_proj", "down_exps.weight",
		"experts.gate_up_proj", "gate_up_exps.weight",
		"router", "gate_inp",
		"patch_embedding.linear", "patch_embedding",
	)
}

// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	var textTensors []Tensor
	for _, t := range ts {
		if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
			// gate and up projectors are fused
			// dims[1], dims[2] must be swapped
			// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
			halfDim := int(t.Shape()[2]) / 2

			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2]/2, newShape[1]
			for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
				// clone tensor since we need separate repackers
				tt := t.Clone()
				tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
				out = append(out, ggml.Tensor{
					Name:     strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
					Kind:     tt.Kind(),
					Shape:    newShape,
					WriterTo: tt,
				})
			}
		} else if strings.Contains(t.Name(), "ffn_down_exps") {
			// dims[1], dims[2] must be swapped
			// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
			t.SetRepacker(p.repack())
			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2], newShape[1]
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    newShape,
				WriterTo: t,
			})
		} else {
			textTensors = append(textTensors, t)
		}
	}

	p.TextModel.skipRepack = true
	out = append(out, p.TextModel.Tensors(textTensors)...)
	return out
}

func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
	return func(name string, data []float32, shape []uint64) ([]float32, error) {
		dims := make([]int, len(shape))
		for i, dim := range shape {
			dims[i] = int(dim)
		}

		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
		t, err := t.Slice(slice...)
		if err != nil {
			return nil, err
		}

		if err := t.T(0, 2, 1); err != nil {
			return nil, err
		}

		t = tensor.Materialize(t)
		// flatten tensor so it can be returned as a vector
		if err := t.Reshape(t.Shape().TotalSize()); err != nil {
			return nil, err
		}

		return native.VectorF32(t.(*tensor.Dense))
	}
}
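The repack logic above splits the fused experts.gate_up_proj tensor into separate gate and up tensors while swapping dims 1 and 2. A minimal sketch (not part of the commit) of that shape bookkeeping, using hypothetical dimensions; the real sizes come from the safetensors metadata:

package main

import "fmt"

func main() {
	// Hypothetical dimensions: [experts, hidden_size, intermediate_size*2].
	shape := []uint64{16, 5120, 2 * 8192}
	halfDim := int(shape[2]) / 2

	// Same swap as in Tensors above: dims 1 and 2 trade places and the fused dim is halved.
	newShape := append([]uint64(nil), shape...)
	newShape[1], newShape[2] = newShape[2]/2, newShape[1]

	for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
		lo, hi := i*halfDim, (i+1)*halfDim
		fmt.Printf("%s: columns [%d:%d) of %v repacked to %v\n", name, lo, hi, shape, newShape)
	}
}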
convert/reader.go

@@ -11,14 +11,15 @@ type Tensor interface {
 	Name() string
 	Shape() []uint64
 	Kind() uint32
-	SetRepacker(repacker)
+	SetRepacker(Repacker)
 	WriteTo(io.Writer) (int64, error)
+	Clone() Tensor
 }
 
 type tensorBase struct {
 	name  string
 	shape []uint64
-	repacker
+	repacker Repacker
 }
 
 func (t tensorBase) Name() string {

@@ -36,7 +37,8 @@ const (
 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
-		t.name == "token_types.weight" {
+		t.name == "token_types.weight" ||
+		t.name == "v.positional_embedding_vlm" {
 		// these tensors are always F32
 		return 0
 	}

@@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 {
 	}
 }
 
-func (t *tensorBase) SetRepacker(fn repacker) {
+func (t *tensorBase) SetRepacker(fn Repacker) {
 	t.repacker = fn
 }
 
-type repacker func(string, []float32, []uint64) ([]float32, error)
+type Repacker func(string, []float32, []uint64) ([]float32, error)
 
 func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 	patterns := []struct {
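With the repacker type now exported as Repacker, converters elsewhere in the package can attach their own repack functions to tensors. A minimal sketch of a function satisfying that signature; transpose2D is a hypothetical example, not part of the commit:

package main

import "fmt"

// transpose2D matches the exported Repacker signature,
// func(string, []float32, []uint64) ([]float32, error): it receives the tensor name,
// its flat data and shape, and returns replacement data (here, the 2D transpose).
func transpose2D(name string, data []float32, shape []uint64) ([]float32, error) {
	rows, cols := int(shape[0]), int(shape[1])
	out := make([]float32, len(data))
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			out[c*rows+r] = data[r*cols+c]
		}
	}
	return out, nil
}

func main() {
	out, _ := transpose2D("example", []float32{1, 2, 3, 4, 5, 6}, []uint64{2, 3})
	fmt.Println(out) // [1 4 2 5 3 6]
}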
convert/reader_safetensors.go

@@ -94,6 +94,21 @@ type safetensor struct {
 	*tensorBase
 }
 
+func (st safetensor) Clone() Tensor {
+	return &safetensor{
+		fs:     st.fs,
+		path:   st.path,
+		dtype:  st.dtype,
+		offset: st.offset,
+		size:   st.size,
+		tensorBase: &tensorBase{
+			name:     st.name,
+			repacker: st.repacker,
+			shape:    slices.Clone(st.shape),
+		},
+	}
+}
+
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	f, err := st.fs.Open(st.path)
 	if err != nil {
convert/reader_torch.go

@@ -43,6 +43,17 @@ type torch struct {
 	*tensorBase
 }
 
+func (t torch) Clone() Tensor {
+	return torch{
+		storage: t.storage,
+		tensorBase: &tensorBase{
+			name:     t.name,
+			shape:    t.shape,
+			repacker: t.repacker,
+		},
+	}
+}
+
 func (pt torch) WriteTo(w io.Writer) (int64, error) {
 	return 0, nil
 }
fs/ggml/ggml.go

@@ -124,6 +124,7 @@ func (kv KV) OllamaEngineRequired() bool {
 	return slices.Contains([]string{
 		"gemma3",
 		"mistral3",
+		"llama4",
 	}, kv.Architecture())
 }
ml/backend.go

@@ -133,6 +133,7 @@ type Tensor interface {
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
+	MulmatID(ctx Context, t2, ids Tensor) Tensor
 
 	Softmax(ctx Context) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor

@@ -150,6 +151,7 @@ type Tensor interface {
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
+	Sigmoid(ctx Context) Tensor
 
 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor

@@ -168,6 +170,8 @@ type Tensor interface {
 	Rows(ctx Context, t2 Tensor) Tensor
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor
+
+	TopK(ctx Context, k int) Tensor
 }
 
 // ScaledDotProductAttention implements a fused attention
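Sigmoid, TopK and MulmatID are added to the Tensor interface to support expert routing in the llama4 MoE layers. A plain-Go sketch (not part of the commit) of the routing arithmetic those ops are used for: keep the k largest router logits and weight each selected expert by the sigmoid of its logit.

package main

import (
	"fmt"
	"math"
	"sort"
)

// topKSigmoid returns the indices of the k largest logits and their sigmoid weights.
func topKSigmoid(logits []float32, k int) ([]int, []float32) {
	indices := make([]int, len(logits))
	for i := range indices {
		indices[i] = i
	}
	sort.Slice(indices, func(a, b int) bool { return logits[indices[a]] > logits[indices[b]] })
	indices = indices[:k]

	weights := make([]float32, k)
	for i, idx := range indices {
		weights[i] = float32(1 / (1 + math.Exp(-float64(logits[idx]))))
	}
	return indices, weights
}

func main() {
	idx, w := topKSigmoid([]float32{0.1, 2.3, -1.0, 0.7}, 2)
	fmt.Println(idx, w) // [1 3] and roughly [0.909 0.668]
}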
ml/backend/ggml/ggml.go

@@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
-	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
-	if b != nil {
-		tt = tt.Add(ctx, b)
+	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+		if b != nil {
+			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
+		}
 	}
 
-	return tt
+	return &Tensor{b: t.b, t: tt}
 }
 
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+	}
+
+	return &Tensor{b: t.b, t: tt}
 }
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {

@@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
 	}
 }
 
+func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
+	}
+}
+
 func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	if len(shape) != 4 {
 		panic("expected 4 dimensions")

@@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
 		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
 	}
 }
+
+func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
+	}
+}
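LayerNorm and RMSNorm are reworked here to tolerate a nil weight (and nil bias), which is what lets the llama4 QK-norm call RMSNorm with a nil weight tensor. A plain-Go sketch (not part of the commit) of the per-row computation with an optional weight:

package main

import (
	"fmt"
	"math"
)

// rmsNorm scales one row by its root-mean-square and optionally applies a weight,
// mirroring the nil-weight branch added to RMSNorm above.
func rmsNorm(x, weight []float32, eps float32) []float32 {
	var ss float64
	for _, v := range x {
		ss += float64(v) * float64(v)
	}
	scale := float32(1 / math.Sqrt(ss/float64(len(x))+float64(eps)))

	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = v * scale
		if weight != nil { // weight is optional, as in the diff
			out[i] *= weight[i]
		}
	}
	return out
}

func main() {
	fmt.Println(rmsNorm([]float32{1, 2, 3}, nil, 1e-5))
	fmt.Println(rmsNorm([]float32{1, 2, 3}, []float32{0.5, 0.5, 0.5}, 1e-5))
}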
model/models/llama4/model.go  (new file, mode 100644)

package llama4

import (
	"bytes"
	"image"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.BytePairEncoding

	*VisionModel `gguf:"v,vision"`
	*Projector   `gguf:"mm"`
	*TextModel
}

type Projector struct {
	Linear1 *nn.Linear `gguf:"linear_1"`
}

func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
	return p.Linear1.Forward(ctx, visionOutputs)
}

func New(c fs.Config) (model.Model, error) {
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
			},
		),
		VisionModel: newVisionModel(c),
		TextModel:   newTextModel(c),
	}

	m.Cache = kvcache.NewWrapperCache(
		// TODO: pretend this is chunked attention for now
		kvcache.NewSWACache(8192, m.Shift),
		kvcache.NewCausalCache(m.Shift),
	)

	return &m, nil
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	if len(m.VisionModel.Layers) < 1 {
		return nil, model.ErrNoVisionModel
	}

	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	f32s, aspectRatio, err := m.ProcessImage(ctx, img)
	if err != nil {
		return nil, err
	}

	pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
	if err != nil {
		return nil, err
	}

	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
	return m.Projector.Forward(ctx, visionOutputs), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
	if err != nil {
		return nil, err
	}

	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
	if err != nil {
		return nil, err
	}

	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}

func init() {
	model.Register("llama4", New)
}
model/models/llama4/model_text.go  (new file, mode 100644)

package llama4

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model/input"
)

type TextAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_factors"`
}

func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	if useRope {
		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)

		if opts.useQKNorm {
			query = query.RMSNorm(ctx, nil, opts.eps)
			key = key.RMSNorm(ctx, nil, opts.eps)
		}
	}

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return sa.Output.Forward(ctx, attention)
}

type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type TextExperts struct {
	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
	Up   ml.Tensor `gguf:"ffn_up_exps.weight"`
	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}

func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
	experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
	scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
	hiddenStates = hiddenStates.Mul(ctx, scores)

	upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
	gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
	downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)

	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
	}

	return nextStates
}

// TextSharedExpert is TextMLP with different names
type TextSharedExpert struct {
	Gate *nn.Linear `gguf:"ffn_gate_shexp"`
	Up   *nn.Linear `gguf:"ffn_up_shexp"`
	Down *nn.Linear `gguf:"ffn_down_shexp"`
}

func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type TextMOE struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Experts      *TextExperts
	SharedExpert *TextSharedExpert
}

func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
	routerLogits := moe.Router.Forward(ctx, hiddenStates)

	sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
	routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
	return sharedStates.Add(ctx, routedStates)
}

type TextFeedForward interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}

type TextLayer struct {
	AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
	Attention     *TextAttention

	FFNNorm     *nn.LayerNorm `gguf:"ffn_norm"`
	FeedForward TextFeedForward
}

func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)

	return residual.Add(ctx, hiddenStates)
}

type TextOptions struct {
	hiddenSize                    int
	numHeads, numKVHeads, headDim int
	numExperts, numExpertsUsed    int
	ropeDim                       int
	ropeBase, ropeScale           float32
	eps                           float32
	interleaveLayerStep           int
	useQKNorm                     bool
}

type TextModel struct {
	Layers []TextLayer `gguf:"blk"`

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	OutputNorm     *nn.LayerNorm `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

func newTextModel(c fs.Config) *TextModel {
	layers := make([]TextLayer, c.Uint("block_count"))
	interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
	for i := range layers {
		if (i+1)%int(interleaveLayerStep) == 0 {
			layers[i] = TextLayer{FeedForward: &TextMOE{}}
		} else {
			layers[i] = TextLayer{FeedForward: &TextMLP{}}
		}
	}

	return &TextModel{
		Layers: layers,
		TextOptions: &TextOptions{
			hiddenSize:          int(c.Uint("embedding_length")),
			numHeads:            int(c.Uint("attention.head_count")),
			numKVHeads:          int(c.Uint("attention.head_count_kv")),
			headDim:             int(c.Uint("attention.head_dim", 128)),
			numExperts:          int(c.Uint("expert_count")),
			numExpertsUsed:      int(c.Uint("expert_used_count")),
			ropeDim:             int(c.Uint("rope.dimension_count")),
			ropeBase:            c.Float("rope.freq_base"),
			ropeScale:           c.Float("rope.freq_scale", 1),
			eps:                 c.Float("attention.layer_norm_rms_epsilon"),
			interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
			useQKNorm:           c.Bool("use_qk_norm", true),
		},
	}
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
		cache.SetLayer(i)

		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(1)

		useChunkedAttention := (i+1)%4 != 0
		if useChunkedAttention {
			wc.SetLayerType(0)
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}
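newTextModel alternates MoE and dense feed-forward blocks according to interleave_moe_layer_step, and Forward switches every fourth layer to the causal cache with RoPE disabled while the rest use the chunked (SWA) cache with RoPE. A small standalone sketch (not part of the commit, with illustrative sizes rather than the model's real hyperparameters) that prints the resulting layer plan:

package main

import "fmt"

func main() {
	// Illustrative values only; the real values come from the GGUF KV data.
	blockCount, interleaveMOELayerStep := 8, 2

	for i := 0; i < blockCount; i++ {
		ffn := "dense MLP"
		if (i+1)%interleaveMOELayerStep == 0 {
			ffn = "MoE"
		}

		attn := "chunked (SWA cache), RoPE"
		if (i+1)%4 == 0 {
			attn = "global (causal cache), no RoPE"
		}

		fmt.Printf("layer %2d: %-9s | %s\n", i, ffn, attn)
	}
}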
model/models/llama4/model_vision.go  (new file, mode 100644)

package llama4

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type VisionAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
//	cos, sin = torch.cos(freqs), torch.sin(freqs)
//	cos = cos.unsqueeze(-1)
//	sin = sin.unsqueeze(-1)
//	t = t.reshape(*t.shape[:-1], -1, 2)
//	t_out = (t * cos) + (_rotate_half(t) * sin)
//	t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
//	t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
//	freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_)  # freqs_ci[:,:,None,:]
//	freqs_ci = freqs_ci.to(t_.device)
//	t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))

	// t1 = t[..., 0::2]
	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)

	// t2 = t[..., 1::2]
	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)

	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)

	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)

	return cosOut.Add(ctx, sinOut)
}

func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))

	query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
	key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
	return sa.Output.Forward(ctx, attention)
}

type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
	return hiddenStates
}

type VisionLayer struct {
	InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
	*VisionAttention

	PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
	*VisionMLP        `gguf:"mlp"`
}

func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	// MLP
	residual = hiddenStates
	hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	return hiddenStates
}

type VisionAdapter struct {
	FC1 *nn.Linear `gguf:"mlp.fc1"`
	FC2 *nn.Linear `gguf:"mlp.fc2"`
}

func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	patches := hiddenStates.Dim(1)
	patchSize := int(math.Sqrt(float64(patches)))

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))

	channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
	channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)

	hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
	return hiddenStates
}

type VisionOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int
	ropeTheta            float32
	eps                  float32
	pixelShuffleRatio    float32
}

type PatchEmbedding struct {
	*nn.Linear
}

func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
	hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
	return p.Linear.Forward(ctx, hiddenStates)
}

type VisionModel struct {
	Layers []VisionLayer `gguf:"blk"`

	*PatchEmbedding     `gguf:"patch_embedding"`
	ClassEmbedding      ml.Tensor `gguf:"class_embedding"`
	PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`

	LayerNormPre  *nn.LayerNorm `gguf:"layernorm_pre"`
	LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`

	*VisionAdapter `gguf:"vision_adapter"`

	*VisionOptions
}

func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionLayer, c.Uint("vision.block_count")),
		VisionOptions: &VisionOptions{
			hiddenSize:        int(c.Uint("vision.embedding_length")),
			numHeads:          int(c.Uint("vision.attention.head_count")),
			imageSize:         int(c.Uint("vision.image_size")),
			patchSize:         int(c.Uint("vision.patch_size")),
			ropeTheta:         float32(c.Float("vision.rope.freq_base")),
			eps:               c.Float("vision.layer_norm_epsilon"),
			pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
		},
	}
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
	hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)

	hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
	hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)

	cos, sin := m.rotaryEmbedding(ctx)
	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
	}

	hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
	hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
	hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
	return hiddenStates
}

// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function
// which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
	if b == 0 {
		panic("division by zero")
	}

	if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
		return a / b
	}

	return a/b - 1
}

func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
	patchesPerSide := m.imageSize / m.patchSize
	numPatches := patchesPerSide*patchesPerSide + 1

	headDim := m.hiddenSize / m.numHeads
	freqDim := headDim / 2

	freqs := make([]float32, numPatches*freqDim)
	for i := range numPatches - 1 {
		for j := 0; j < freqDim; j += 2 {
			positionX := i*freqDim/2 + j/2
			positionY := (i+numPatches)*freqDim/2 + j/2

			ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
			freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
			freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
		}
	}

	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
	if err != nil {
		panic(err)
	}

	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)

	return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
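floorDiv exists because Go's integer division truncates toward zero, while the reference PyTorch code relies on Python's floor semantics. A small standalone comparison (not part of the commit):

package main

import "fmt"

// floorDiv rounds toward negative infinity, like Python's // operator, whereas Go's
// built-in integer division truncates toward zero.
func floorDiv(a, b int) int {
	if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
		return a / b
	}
	return a/b - 1
}

func main() {
	for _, p := range [][2]int{{7, 2}, {-7, 2}, {7, -2}, {-8, 2}} {
		a, b := p[0], p[1]
		fmt.Printf("a=%2d b=%2d  truncating=%2d  floor=%2d\n", a, b, a/b, floorDiv(a, b))
	}
}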
model/models/models.go

@@ -4,6 +4,7 @@ import (
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/llama"
+	_ "github.com/ollama/ollama/model/models/llama4"
 	_ "github.com/ollama/ollama/model/models/mistral3"
 	_ "github.com/ollama/ollama/model/models/mllama"
 )