Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
63a39406
Commit
63a39406
authored
Mar 11, 2025
by
Michael Yang
Browse files
use 2d pooling
parent
ab39e08e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
25 deletions
+36
-25
convert/convert_gemma3.go
convert/convert_gemma3.go
+14
-9
ml/backend.go
ml/backend.go
+1
-1
ml/backend/ggml/ggml.go
ml/backend/ggml/ggml.go
+3
-3
model/models/gemma3/model.go
model/models/gemma3/model.go
+18
-12
No files found.
convert/convert_gemma3.go
View file @
63a39406
...
@@ -26,15 +26,16 @@ type gemma3Model struct {
...
@@ -26,15 +26,16 @@ type gemma3Model struct {
NumChannels
uint32
`json:"num_channels"`
// num_channels 3
NumChannels
uint32
`json:"num_channels"`
// num_channels 3
PatchSize
uint32
`json:"patch_size"`
// patch_size 14
PatchSize
uint32
`json:"patch_size"`
// patch_size 14
}
`json:"vision_config"`
}
`json:"vision_config"`
MaxPositionEmbeddings
uint32
`json:"max_position_embeddings"`
MaxPositionEmbeddings
uint32
`json:"max_position_embeddings"`
NumAttentionHeads
uint32
`json:"num_attention_heads"`
NumAttentionHeads
uint32
`json:"num_attention_heads"`
NumKeyValueHeads
uint32
`json:"num_key_value_heads"`
NumKeyValueHeads
uint32
`json:"num_key_value_heads"`
RMSNormEPS
float32
`json:"rms_norm_eps"`
RMSNormEPS
float32
`json:"rms_norm_eps"`
HeadDim
uint32
`json:"head_dim"`
HeadDim
uint32
`json:"head_dim"`
FinalLogitSoftcap
float32
`json:"final_logit_softcapping"`
FinalLogitSoftcap
float32
`json:"final_logit_softcapping"`
RopeLocalTheta
float32
`json:"rope_local_base_freq"`
RopeLocalTheta
float32
`json:"rope_local_base_freq"`
RopeGlobalTheta
float32
`json:"rope_global_base_freq"`
RopeGlobalTheta
float32
`json:"rope_global_base_freq"`
SlidingWindow
uint32
`json:"sliding_window"`
SlidingWindow
uint32
`json:"sliding_window"`
MultiModalTokensPerImage
uint32
`json:"mm_tokens_per_image"`
}
}
const
(
const
(
...
@@ -102,6 +103,10 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
...
@@ -102,6 +103,10 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv
[
"gemma3.attention.value_length"
]
=
cmp
.
Or
(
p
.
TextModel
.
HeadDim
,
256
)
kv
[
"gemma3.attention.value_length"
]
=
cmp
.
Or
(
p
.
TextModel
.
HeadDim
,
256
)
}
}
if
p
.
MultiModalTokensPerImage
>
0
{
kv
[
"gemma3.mm.tokens_per_image"
]
=
p
.
MultiModalTokensPerImage
}
return
kv
return
kv
}
}
...
...
ml/backend.go
View file @
63a39406
...
@@ -135,7 +135,7 @@ type Tensor interface {
...
@@ -135,7 +135,7 @@ type Tensor interface {
RMSNorm
(
ctx
Context
,
weight
Tensor
,
eps
float32
)
Tensor
RMSNorm
(
ctx
Context
,
weight
Tensor
,
eps
float32
)
Tensor
Scale
(
ctx
Context
,
s
float64
)
Tensor
Scale
(
ctx
Context
,
s
float64
)
Tensor
AvgPool1D
(
ctx
Context
,
k
,
s
,
p
int
)
Tensor
AvgPool2D
(
ctx
Context
,
k
,
s
int
,
p
float32
)
Tensor
Conv2D
(
ctx
Context
,
weight
Tensor
,
s0
,
s1
,
p0
,
p1
,
d0
,
d1
int
)
Tensor
Conv2D
(
ctx
Context
,
weight
Tensor
,
s0
,
s1
,
p0
,
p1
,
d0
,
d1
int
)
Tensor
RoPE
(
ctx
Context
,
positionIDs
,
ropeFactors
Tensor
,
dim
,
ropeType
uint32
,
base
,
scale
float32
)
Tensor
RoPE
(
ctx
Context
,
positionIDs
,
ropeFactors
Tensor
,
dim
,
ropeType
uint32
,
base
,
scale
float32
)
Tensor
...
...
ml/backend/ggml/ggml.go
View file @
63a39406
...
@@ -247,7 +247,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
...
@@ -247,7 +247,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
createTensor
(
tensor
{
source
:
t
},
output
.
bts
)
createTensor
(
tensor
{
source
:
t
},
output
.
bts
)
case
strings
.
HasPrefix
(
t
.
Name
,
"v."
)
||
strings
.
HasPrefix
(
t
.
Name
,
"mm."
)
:
case
strings
.
HasPrefix
(
t
.
Name
,
"v."
)
||
strings
.
HasPrefix
(
t
.
Name
,
"mm."
)
:
// TODO: assign vision tensors to the gpu if possible
// TODO: assign vision tensors to the gpu if possible
createTensor
(
tensor
{
source
:
t
},
input
.
bts
)
createTensor
(
tensor
{
source
:
t
},
output
.
bts
)
case
contains
(
t
.
Name
,
"rope_freqs"
,
"rope_factors_long"
,
"rope_factors_short"
)
:
case
contains
(
t
.
Name
,
"rope_freqs"
,
"rope_factors_long"
,
"rope_factors_short"
)
:
// these tensors should be repeated per layer
// these tensors should be repeated per layer
for
i
,
layer
:=
range
layers
{
for
i
,
layer
:=
range
layers
{
...
@@ -952,10 +952,10 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
...
@@ -952,10 +952,10 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
}
}
func
(
t
*
Tensor
)
AvgPool
1
D
(
ctx
ml
.
Context
,
k
,
s
,
p
int
)
ml
.
Tensor
{
func
(
t
*
Tensor
)
AvgPool
2
D
(
ctx
ml
.
Context
,
k
,
s
int
,
p
float32
)
ml
.
Tensor
{
return
&
Tensor
{
return
&
Tensor
{
b
:
t
.
b
,
b
:
t
.
b
,
t
:
C
.
ggml_pool_
1
d
(
ctx
.
(
*
Context
)
.
ctx
,
t
.
t
,
C
.
GGML_OP_POOL_AVG
,
C
.
int
(
k
),
C
.
int
(
s
),
C
.
in
t
(
p
)),
t
:
C
.
ggml_pool_
2
d
(
ctx
.
(
*
Context
)
.
ctx
,
t
.
t
,
C
.
GGML_OP_POOL_AVG
,
C
.
int
(
k
),
C
.
int
(
k
),
C
.
int
(
s
),
C
.
int
(
s
),
C
.
float
(
p
),
C
.
floa
t
(
p
)),
}
}
}
}
...
...
model/models/gemma3/model.go
View file @
63a39406
...
@@ -5,6 +5,7 @@ import (
...
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/binary"
"hash/fnv"
"hash/fnv"
"image"
"image"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml"
...
@@ -30,9 +31,21 @@ var _ model.MultimodalProcessor = (*Model)(nil)
...
@@ -30,9 +31,21 @@ var _ model.MultimodalProcessor = (*Model)(nil)
type
MultiModalProjector
struct
{
type
MultiModalProjector
struct
{
SoftEmbNorm
*
nn
.
RMSNorm
`gguf:"mm_soft_emb_norm"`
SoftEmbNorm
*
nn
.
RMSNorm
`gguf:"mm_soft_emb_norm"`
InputProjection
*
nn
.
Linear
`gguf:"mm_input_projection"`
InputProjection
*
nn
.
Linear
`gguf:"mm_input_projection"`
tokensPerImage
int
}
}
func
(
p
*
MultiModalProjector
)
Forward
(
ctx
ml
.
Context
,
visionOutputs
ml
.
Tensor
,
eps
float32
)
ml
.
Tensor
{
func
(
p
*
MultiModalProjector
)
Forward
(
ctx
ml
.
Context
,
visionOutputs
ml
.
Tensor
,
imageSize
,
patchSize
int
,
eps
float32
)
ml
.
Tensor
{
l
:=
visionOutputs
.
Dim
(
0
)
visionOutputs
=
visionOutputs
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
patchesPerImage
:=
imageSize
/
patchSize
visionOutputs
=
visionOutputs
.
Reshape
(
ctx
,
patchesPerImage
,
patchesPerImage
,
l
)
kernelSize
:=
patchesPerImage
/
int
(
math
.
Sqrt
(
float64
(
p
.
tokensPerImage
)))
visionOutputs
=
visionOutputs
.
AvgPool2D
(
ctx
,
kernelSize
,
kernelSize
,
0
)
visionOutputs
=
visionOutputs
.
Reshape
(
ctx
,
visionOutputs
.
Dim
(
0
)
*
visionOutputs
.
Dim
(
1
),
l
)
visionOutputs
=
visionOutputs
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
visionOutputs
=
p
.
SoftEmbNorm
.
Forward
(
ctx
,
visionOutputs
,
eps
)
visionOutputs
=
p
.
SoftEmbNorm
.
Forward
(
ctx
,
visionOutputs
,
eps
)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
...
@@ -59,6 +72,9 @@ func New(c ml.Config) (model.Model, error) {
...
@@ -59,6 +72,9 @@ func New(c ml.Config) (model.Model, error) {
ImageProcessor
:
newImageProcessor
(
c
),
ImageProcessor
:
newImageProcessor
(
c
),
VisionModel
:
newVisionModel
(
c
),
VisionModel
:
newVisionModel
(
c
),
TextModel
:
newTextModel
(
c
),
TextModel
:
newTextModel
(
c
),
MultiModalProjector
:
&
MultiModalProjector
{
tokensPerImage
:
int
(
c
.
Uint
(
"mm_tokens_per_image"
,
256
)),
},
}
}
slidingWindowLen
:=
int32
(
c
.
Uint
(
"attention.sliding_window"
))
slidingWindowLen
:=
int32
(
c
.
Uint
(
"attention.sliding_window"
))
...
@@ -88,17 +104,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
...
@@ -88,17 +104,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
}
}
visionOutputs
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
)
visionOutputs
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
)
visionOutputs
=
visionOutputs
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
visionOutputs
=
m
.
MultiModalProjector
.
Forward
(
ctx
,
visionOutputs
,
m
.
imageSize
,
m
.
patchSize
,
m
.
VisionModel
.
eps
)
patchesPerImage
:=
m
.
ImageProcessor
.
imageSize
/
m
.
ImageProcessor
.
patchSize
// TODO (jmorganca): read this from the model config
// it should instead be math.Sqrt(tokens per image)
tokensPerSide
:=
8
kernelSize
:=
patchesPerImage
/
tokensPerSide
visionOutputs
=
visionOutputs
.
AvgPool1D
(
ctx
,
kernelSize
,
kernelSize
,
0
)
visionOutputs
=
visionOutputs
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
visionOutputs
=
m
.
MultiModalProjector
.
Forward
(
ctx
,
visionOutputs
,
m
.
VisionModel
.
eps
)
return
visionOutputs
,
nil
return
visionOutputs
,
nil
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment