Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
aee28501
Unverified
Commit
aee28501
authored
Mar 11, 2025
by
Michael Yang
Committed by
GitHub
Mar 11, 2025
Browse files
Merge pull request #9661 from ollama/gemma
engine: add gemma support
parents
4dcf8016
83f0ec82
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
588 additions
and
14 deletions
+588
-14
model/models/gemma3/model_vision.go
model/models/gemma3/model_vision.go
+127
-0
model/models/gemma3/process_image.go
model/models/gemma3/process_image.go
+58
-0
model/models/llama/model.go
model/models/llama/model.go
+4
-3
model/models/mllama/model_text.go
model/models/mllama/model_text.go
+5
-3
model/models/mllama/process_image.go
model/models/mllama/process_image.go
+0
-2
model/models/models.go
model/models/models.go
+2
-0
model/process_text.go
model/process_text.go
+17
-5
model/process_text_spm.go
model/process_text_spm.go
+246
-0
model/process_text_spm_test.go
model/process_text_spm_test.go
+118
-0
model/testdata/gemma2/tokenizer.model
model/testdata/gemma2/tokenizer.model
+0
-0
server/prompt.go
server/prompt.go
+11
-1
No files found.
model/models/gemma3/model_vision.go
0 → 100644
View file @
aee28501
package
gemma3
import
(
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// batchSize is fixed at 1: the vision encoder processes one image at a time.
var batchSize int = 1

// VisionSelfAttention is the multi-head self-attention block of a vision
// encoder layer. The gguf tags name the tensors these projections are
// loaded from in the model file.
type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
func
(
sa
*
VisionSelfAttention
)
Forward
(
ctx
ml
.
Context
,
hiddenState
ml
.
Tensor
,
opts
*
VisionModelOptions
)
ml
.
Tensor
{
headDim
:=
opts
.
hiddenSize
/
opts
.
numHeads
query
:=
sa
.
Query
.
Forward
(
ctx
,
hiddenState
)
key
:=
sa
.
Key
.
Forward
(
ctx
,
hiddenState
)
value
:=
sa
.
Value
.
Forward
(
ctx
,
hiddenState
)
query
=
query
.
Reshape
(
ctx
,
headDim
,
opts
.
numHeads
,
query
.
Dim
(
1
),
batchSize
)
key
=
key
.
Reshape
(
ctx
,
headDim
,
opts
.
numHeads
,
key
.
Dim
(
1
),
batchSize
)
value
=
value
.
Reshape
(
ctx
,
headDim
,
opts
.
numHeads
,
value
.
Dim
(
1
),
batchSize
)
attention
:=
nn
.
Attention
(
ctx
,
query
,
key
,
value
,
1.0
/
math
.
Sqrt
(
float64
(
headDim
)),
nil
)
attention
=
attention
.
Reshape
(
ctx
,
opts
.
hiddenSize
,
attention
.
Dim
(
2
),
batchSize
)
hiddenState
=
sa
.
Output
.
Forward
(
ctx
,
attention
)
return
hiddenState
}
// VisionMLP is the feed-forward block of a vision encoder layer,
// applied as FC1 -> GELU -> FC2 (see Forward).
type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}
func
(
mlp
*
VisionMLP
)
Forward
(
ctx
ml
.
Context
,
hiddenState
ml
.
Tensor
,
opts
*
VisionModelOptions
)
ml
.
Tensor
{
hiddenState
=
mlp
.
FC1
.
Forward
(
ctx
,
hiddenState
)
.
GELU
(
ctx
)
hiddenState
=
mlp
.
FC2
.
Forward
(
ctx
,
hiddenState
)
return
hiddenState
}
// VisionEncoderLayer is one pre-norm transformer encoder layer:
// LayerNorm1 -> self-attention -> residual, then
// LayerNorm2 -> MLP -> residual (see Forward).
type VisionEncoderLayer struct {
	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
	SelfAttention *VisionSelfAttention

	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
	MLP        *VisionMLP    `gguf:"mlp"`
}
func
(
e
*
VisionEncoderLayer
)
Forward
(
ctx
ml
.
Context
,
hiddenState
ml
.
Tensor
,
opts
*
VisionModelOptions
)
ml
.
Tensor
{
residual
:=
hiddenState
// self attention
hiddenState
=
e
.
LayerNorm1
.
Forward
(
ctx
,
hiddenState
,
opts
.
eps
)
hiddenState
=
e
.
SelfAttention
.
Forward
(
ctx
,
hiddenState
,
opts
)
hiddenState
=
hiddenState
.
Add
(
ctx
,
residual
)
residual
=
hiddenState
// feed forward
hiddenState
=
e
.
LayerNorm2
.
Forward
(
ctx
,
hiddenState
,
opts
.
eps
)
hiddenState
=
e
.
MLP
.
Forward
(
ctx
,
hiddenState
,
opts
)
return
hiddenState
.
Add
(
ctx
,
residual
)
}
// VisionModelOptions holds the hyperparameters shared by every layer of the
// vision tower, read from the model config in newVisionModel.
type VisionModelOptions struct {
	hiddenSize, numHeads int // embedding width and attention head count
	imageSize, patchSize int // square input size and patch edge, in pixels
	eps                  float32 // layer-norm epsilon
}
// VisionModel is the gemma3 vision tower: a convolutional patch embedding,
// learned position embeddings, a stack of encoder layers, and a final
// layer norm. The gguf tags name the tensors in the model file.
type VisionModel struct {
	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embedding"`
	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`

	Layers []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
}
func
(
m
*
VisionModel
)
Forward
(
ctx
ml
.
Context
,
pixelValues
ml
.
Tensor
)
ml
.
Tensor
{
numPatches
:=
(
m
.
imageSize
/
m
.
patchSize
)
*
(
m
.
imageSize
/
m
.
patchSize
)
hiddenState
:=
m
.
PatchEmbedding
.
Forward
(
ctx
,
pixelValues
,
m
.
patchSize
,
m
.
patchSize
,
0
,
0
,
1
,
1
)
hiddenState
=
hiddenState
.
Reshape
(
ctx
,
numPatches
,
m
.
hiddenSize
)
hiddenState
=
hiddenState
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
positions
:=
make
([]
int32
,
numPatches
)
for
i
:=
range
positions
{
positions
[
i
]
=
int32
(
i
)
}
positionIDs
,
err
:=
ctx
.
Input
()
.
FromIntSlice
(
positions
,
len
(
positions
))
if
err
!=
nil
{
panic
(
err
)
}
hiddenState
=
hiddenState
.
Add
(
ctx
,
m
.
PositionEmbedding
.
Forward
(
ctx
,
positionIDs
))
for
_
,
layer
:=
range
m
.
Layers
{
hiddenState
=
layer
.
Forward
(
ctx
,
hiddenState
,
m
.
VisionModelOptions
)
}
hiddenState
=
m
.
PostLayerNorm
.
Forward
(
ctx
,
hiddenState
,
m
.
eps
)
return
hiddenState
}
func
newVisionModel
(
c
ml
.
Config
)
*
VisionModel
{
return
&
VisionModel
{
Layers
:
make
([]
VisionEncoderLayer
,
c
.
Uint
(
"vision.block_count"
)),
VisionModelOptions
:
&
VisionModelOptions
{
hiddenSize
:
int
(
c
.
Uint
(
"vision.embedding_length"
)),
numHeads
:
int
(
c
.
Uint
(
"vision.attention.head_count"
)),
imageSize
:
int
(
c
.
Uint
(
"vision.image_size"
)),
patchSize
:
int
(
c
.
Uint
(
"vision.patch_size"
)),
eps
:
c
.
Float
(
"vision.attention.layer_norm_epsilon"
),
},
}
}
model/models/gemma3/process_image.go
0 → 100644
View file @
aee28501
package
gemma3
import
(
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor prepares input images for the gemma3 vision tower.
// All fields are read from the "vision.*" model config in newImageProcessor.
type ImageProcessor struct {
	imageSize, patchSize, numChannels int
}
func
newImageProcessor
(
c
ml
.
Config
)
ImageProcessor
{
return
ImageProcessor
{
imageSize
:
int
(
c
.
Uint
(
"vision.image_size"
)),
patchSize
:
int
(
c
.
Uint
(
"vision.patch_size"
)),
numChannels
:
int
(
c
.
Uint
(
"vision.num_channels"
)),
}
}
func
(
p
*
ImageProcessor
)
pack
(
img
image
.
Image
,
mean
,
std
[
3
]
float32
)
[]
float32
{
var
pixelVals
,
rVals
,
gVals
,
bVals
[]
float32
bounds
:=
img
.
Bounds
()
for
y
:=
bounds
.
Min
.
Y
;
y
<
bounds
.
Max
.
Y
;
y
++
{
for
x
:=
bounds
.
Min
.
X
;
x
<
bounds
.
Max
.
X
;
x
++
{
c
:=
img
.
At
(
x
,
y
)
r
,
g
,
b
,
_
:=
c
.
RGBA
()
rVal
:=
float32
(
r
>>
8
)
/
255.0
gVal
:=
float32
(
g
>>
8
)
/
255.0
bVal
:=
float32
(
b
>>
8
)
/
255.0
rVal
=
(
rVal
-
mean
[
0
])
/
std
[
0
]
gVal
=
(
gVal
-
mean
[
1
])
/
std
[
1
]
bVal
=
(
bVal
-
mean
[
2
])
/
std
[
2
]
rVals
=
append
(
rVals
,
rVal
)
gVals
=
append
(
gVals
,
gVal
)
bVals
=
append
(
bVals
,
bVal
)
}
}
pixelVals
=
append
(
pixelVals
,
rVals
...
)
pixelVals
=
append
(
pixelVals
,
gVals
...
)
pixelVals
=
append
(
pixelVals
,
bVals
...
)
return
pixelVals
}
func
(
p
ImageProcessor
)
ProcessImage
(
img
image
.
Image
)
([]
float32
,
error
)
{
outputSize
:=
image
.
Point
{
p
.
imageSize
,
p
.
imageSize
}
newImage
:=
imageproc
.
Composite
(
img
)
newImage
=
imageproc
.
Resize
(
newImage
,
outputSize
,
imageproc
.
ResizeBilinear
)
data
:=
p
.
pack
(
newImage
,
imageproc
.
ImageNetStandardMean
,
imageproc
.
ImageNetStandardSTD
)
return
data
,
nil
}
model/models/llama/model.go
View file @
aee28501
...
...
@@ -76,14 +76,15 @@ type SelfAttention struct {
func
(
sa
*
SelfAttention
)
Forward
(
ctx
ml
.
Context
,
hiddenState
,
positionIDs
ml
.
Tensor
,
cache
kvcache
.
Cache
,
opts
*
Options
)
ml
.
Tensor
{
batchSize
:=
hiddenState
.
Dim
(
1
)
headDim
:=
opts
.
hiddenSize
/
opts
.
numHeads
ropeType
:=
uint32
(
0
)
q
:=
sa
.
Query
.
Forward
(
ctx
,
hiddenState
)
q
=
q
.
Reshape
(
ctx
,
headDim
,
opts
.
numHeads
,
batchSize
)
q
=
q
.
RoPE
(
ctx
,
positionIDs
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
opts
.
ropeBase
,
opts
.
ropeScale
)
q
=
q
.
RoPE
(
ctx
,
positionIDs
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
ropeType
,
opts
.
ropeBase
,
opts
.
ropeScale
)
k
:=
sa
.
Key
.
Forward
(
ctx
,
hiddenState
)
k
=
k
.
Reshape
(
ctx
,
headDim
,
opts
.
numKVHeads
,
batchSize
)
k
=
k
.
RoPE
(
ctx
,
positionIDs
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
opts
.
ropeBase
,
opts
.
ropeScale
)
k
=
k
.
RoPE
(
ctx
,
positionIDs
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
ropeType
,
opts
.
ropeBase
,
opts
.
ropeScale
)
v
:=
sa
.
Value
.
Forward
(
ctx
,
hiddenState
)
v
=
v
.
Reshape
(
ctx
,
headDim
,
opts
.
numKVHeads
,
batchSize
)
...
...
@@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func
(
m
*
Model
)
Shift
(
ctx
ml
.
Context
,
layer
int
,
key
,
shift
ml
.
Tensor
)
(
ml
.
Tensor
,
error
)
{
return
key
.
RoPE
(
ctx
,
shift
,
m
.
Layers
[
layer
]
.
SelfAttention
.
RopeFactors
,
m
.
ropeDim
,
m
.
ropeBase
,
m
.
ropeScale
),
nil
return
key
.
RoPE
(
ctx
,
shift
,
m
.
Layers
[
layer
]
.
SelfAttention
.
RopeFactors
,
uint32
(
0
),
m
.
ropeDim
,
m
.
ropeBase
,
m
.
ropeScale
),
nil
}
type
MLP
struct
{
...
...
model/models/mllama/model_text.go
View file @
aee28501
...
...
@@ -20,14 +20,15 @@ type TextSelfAttention struct {
func
(
sa
*
TextSelfAttention
)
Forward
(
ctx
ml
.
Context
,
hiddenState
,
positions
,
_
ml
.
Tensor
,
cache
*
kvcache
.
WrapperCache
,
opts
*
TextModelOptions
)
ml
.
Tensor
{
batchSize
:=
hiddenState
.
Dim
(
1
)
headDim
:=
opts
.
hiddenSize
/
opts
.
numHeads
ropeType
:=
uint32
(
0
)
query
:=
sa
.
Query
.
Forward
(
ctx
,
hiddenState
)
query
=
query
.
Reshape
(
ctx
,
headDim
,
opts
.
numHeads
,
batchSize
)
query
=
query
.
RoPE
(
ctx
,
positions
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
opts
.
ropeBase
,
opts
.
ropeScale
)
query
=
query
.
RoPE
(
ctx
,
positions
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
ropeType
,
opts
.
ropeBase
,
opts
.
ropeScale
)
key
:=
sa
.
Key
.
Forward
(
ctx
,
hiddenState
)
key
=
key
.
Reshape
(
ctx
,
headDim
,
opts
.
numKVHeads
,
batchSize
)
key
=
key
.
RoPE
(
ctx
,
positions
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
opts
.
ropeBase
,
opts
.
ropeScale
)
key
=
key
.
RoPE
(
ctx
,
positions
,
sa
.
RopeFactors
,
opts
.
ropeDim
,
ropeType
,
opts
.
ropeBase
,
opts
.
ropeScale
)
value
:=
sa
.
Value
.
Forward
(
ctx
,
hiddenState
)
value
=
value
.
Reshape
(
ctx
,
headDim
,
opts
.
numKVHeads
,
batchSize
)
...
...
@@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}
func
(
m
*
TextModel
)
Shift
(
ctx
ml
.
Context
,
layer
int
,
key
,
shift
ml
.
Tensor
)
(
ml
.
Tensor
,
error
)
{
// This will only get called for layers in the cache, which are just the self attention layers
if
sa
,
ok
:=
m
.
Transformer
.
Layers
[
layer
]
.
(
*
TextSelfAttentionDecoderLayer
);
ok
{
return
key
.
RoPE
(
ctx
,
shift
,
sa
.
SelfAttention
.
RopeFactors
,
m
.
ropeDim
,
m
.
ropeBase
,
m
.
ropeScale
),
nil
return
key
.
RoPE
(
ctx
,
shift
,
sa
.
SelfAttention
.
RopeFactors
,
m
.
ropeDim
,
uint32
(
0
),
m
.
ropeBase
,
m
.
ropeScale
),
nil
}
return
key
,
nil
...
...
model/models/mllama/process_image.go
View file @
aee28501
...
...
@@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return
images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused
...
...
model/models/models.go
View file @
aee28501
package
models
import
(
_
"github.com/ollama/ollama/model/models/gemma2"
_
"github.com/ollama/ollama/model/models/gemma3"
_
"github.com/ollama/ollama/model/models/llama"
_
"github.com/ollama/ollama/model/models/mllama"
)
model/process_text.go
View file @
aee28501
...
...
@@ -4,6 +4,7 @@ import (
"cmp"
"iter"
"log/slog"
"slices"
"strings"
"sync"
...
...
@@ -18,6 +19,15 @@ const (
SpecialEOS
)
const
(
TOKEN_TYPE_NORMAL
=
iota
+
1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type
TextProcessor
interface
{
Encode
(
s
string
,
addSpecial
bool
)
([]
int32
,
error
)
Decode
([]
int32
)
(
string
,
error
)
...
...
@@ -27,11 +37,11 @@ type TextProcessor interface {
type
Vocabulary
struct
{
Values
[]
string
Types
[]
uint32
Scores
[]
uin
t32
Scores
[]
floa
t32
Merges
[]
string
BOS
,
EOS
int32
AddBOS
,
AddEOS
bool
BOS
,
EOS
,
EOT
int32
AddBOS
,
AddEOS
,
AddEOT
bool
specialOnce
sync
.
Once
special
[]
string
...
...
@@ -48,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case
SpecialBOS
:
return
id
==
v
.
BOS
case
SpecialEOS
:
return
id
==
v
.
EOS
return
id
==
v
.
EOS
||
id
==
v
.
EOT
default
:
return
false
}
...
...
@@ -76,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func
(
v
*
Vocabulary
)
SpecialVocabulary
()
[]
string
{
v
.
specialOnce
.
Do
(
func
()
{
for
i
:=
range
v
.
Values
{
if
v
.
Types
[
i
]
==
3
{
if
slices
.
Contains
([]
int
{
105
,
106
},
i
)
{
v
.
special
=
append
(
v
.
special
,
v
.
Values
[
i
])
}
else
if
v
.
Types
[
i
]
==
TOKEN_TYPE_CONTROL
{
v
.
special
=
append
(
v
.
special
,
v
.
Values
[
i
])
}
}
...
...
model/process_text_spm.go
0 → 100644
View file @
aee28501
package
model
import
(
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue
"github.com/emirpasic/gods/v2/queues/priorityqueue"
)
// spmWhitespaceSep is the SentencePiece whitespace marker (U+2581).
const spmWhitespaceSep = "▁"

// replaceWhitespaceBySeperator substitutes every ASCII space in s with the
// SentencePiece whitespace marker.
func replaceWhitespaceBySeperator(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		if r == ' ' {
			b.WriteString(spmWhitespaceSep)
		} else {
			b.WriteRune(r)
		}
	}
	return b.String()
}
// SentencePieceModel is a SentencePiece-style tokenizer backed by a scored
// vocabulary and a pretokenizer regular expression.
type SentencePieceModel struct {
	maxTokenLen int             // longest normal/user-defined/unused token, in bytes
	pre         *regexp2.Regexp // compiled pretokenizer pattern
	vocab       *Vocabulary
}

// Compile-time check that SentencePieceModel satisfies TextProcessor.
var _ TextProcessor = (*SentencePieceModel)(nil)
func
NewSentencePieceModel
(
pre
string
,
vocab
*
Vocabulary
)
SentencePieceModel
{
slog
.
Debug
(
"Tokens"
,
"num tokens"
,
len
(
vocab
.
Values
),
"vals"
,
vocab
.
Values
[
:
5
],
"scores"
,
vocab
.
Scores
[
:
5
],
"types"
,
vocab
.
Types
[
:
5
])
counter
:=
map
[
int
]
int
{}
var
maxTokenLen
int
for
cnt
:=
range
vocab
.
Types
{
switch
vocab
.
Types
[
cnt
]
{
case
TOKEN_TYPE_NORMAL
,
TOKEN_TYPE_USER_DEFINED
,
TOKEN_TYPE_UNUSED
:
maxTokenLen
=
max
(
maxTokenLen
,
len
(
vocab
.
Values
[
cnt
]))
fallthrough
default
:
counter
[
int
(
vocab
.
Types
[
cnt
])]
+=
1
}
}
slog
.
Debug
(
"Token counts"
,
"normal"
,
counter
[
TOKEN_TYPE_NORMAL
],
"unknown"
,
counter
[
TOKEN_TYPE_UNKNOWN
],
"control"
,
counter
[
TOKEN_TYPE_CONTROL
],
"user defined"
,
counter
[
TOKEN_TYPE_USER_DEFINED
],
"unused"
,
counter
[
TOKEN_TYPE_UNUSED
],
"byte"
,
counter
[
TOKEN_TYPE_BYTE
],
"max token len"
,
maxTokenLen
)
return
SentencePieceModel
{
maxTokenLen
:
maxTokenLen
,
pre
:
regexp2
.
MustCompile
(
pre
,
regexp2
.
Unicode
|
regexp2
.
RE2
),
vocab
:
vocab
,
}
}
// Is reports whether id is the given special token, delegating to the
// vocabulary.
func (spm SentencePieceModel) Is(id int32, special Special) bool {
	return spm.vocab.Is(id, special)
}
// split returns an iterator over the substrings of s matched by the
// pretokenizer regular expression, in order of appearance.
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		// Walk successive matches; errors from the regexp2 search are
		// deliberately dropped, treating a failed search the same as
		// no further matches.
		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
			if !yield(m.String()) {
				break
			}
		}
	}
}
// Encode tokenizes s into vocabulary ids. Special tokens in the input are
// first matched verbatim against the special vocabulary; the remaining text
// is pretokenized, whitespace-mapped to the SentencePiece marker, and then
// merged bottom-up by token score. If addSpecial is true, BOS/EOS ids are
// prepended/appended according to the vocabulary's AddBOS/AddEOS flags.
// The returned error is always nil.
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
	fragments := []fragment{{value: s}}
	for _, special := range spm.vocab.SpecialVocabulary() {
		// TODO: process special tokens concurrently
		id := spm.vocab.Encode(special)
		for i := 0; i < len(fragments); i++ {
			frag := fragments[i]
			if len(frag.ids) > 0 {
				// Already resolved to ids by an earlier special token.
				continue
			}

			// Split frag.value around the first occurrence of special,
			// replacing it in place with up to three fragments: the text
			// before it, the special token itself (pre-encoded), and the
			// remainder (revisited on a later iteration of this loop).
			var middle []fragment
			switch i := strings.Index(frag.value, special); {
			case i < 0:
				middle = append(middle, frag)
			case i > 0:
				middle = append(middle, fragment{value: frag.value[:i]})
				fallthrough
			default:
				middle = append(middle, fragment{value: special, ids: []int32{id}})
				if rest := frag.value[i+len(special):]; rest != "" {
					middle = append(middle, fragment{value: rest})
				}
			}

			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
		}
	}
	slog.Debug("fragments", "frags", fragments)

	var ids []int32
	for _, frag := range fragments {
		if len(frag.ids) > 0 {
			// Special-token fragment: ids already known.
			ids = append(ids, frag.ids...)
			continue
		}

		for split := range spm.split(frag.value) {
			split = replaceWhitespaceBySeperator(split)

			var sb strings.Builder
			sb.Write([]byte(split))
			if id := spm.vocab.Encode(sb.String()); id >= 0 {
				// The whole pretoken is a single vocabulary entry.
				ids = append(ids, id)
				continue
			}

			runes := []rune(sb.String())
			// Priority queue of merge candidates: the highest-scoring pair
			// wins; ties break toward the leftmost candidate.
			pq := queue.NewWith(func(a, b any) int {
				priA := a.(*candidate)
				priB := b.(*candidate)
				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
					return -1
				}
				return 1
			})

			// merges acts as a doubly linked list over rune positions:
			// p/n index the previous/next entry, runes holds the merged
			// span (nil once absorbed into a left neighbor).
			merges := make([]merge, len(runes))
			for r := range runes {
				merges[r] = merge{
					p: r - 1,
					n: r + 1,
					runes: []rune{runes[r]},
				}
			}

			slog.Debug("tokenizer", "merges", merges)

			// pairwise proposes merging entries a and b when their combined
			// text is a vocabulary token; returns nil otherwise.
			pairwise := func(a, b int) *candidate {
				if a < 0 || b >= len(runes) {
					return nil
				}

				left, right := string(merges[a].runes), string(merges[b].runes)
				if id := spm.vocab.Encode(left + right); id >= 0 {
					return &candidate{
						a: a,
						b: b,
						score: spm.vocab.Scores[id],
					}
				}

				return nil
			}

			// Seed the queue with every adjacent rune pair.
			for i := range len(runes) - 1 {
				if pair := pairwise(i, i+1); pair != nil {
					pq.Enqueue(pair)
				}
			}

			pqv := pq.Values()
			for _, v := range pqv {
				e := v.(*candidate)
				slog.Debug("candidate", "candidate", e)
			}

			for !pq.Empty() {
				v, _ := pq.Dequeue()
				pair := v.(*candidate)
				left, right := merges[pair.a], merges[pair.b]

				slog.Debug("pair", "left", left, "right", right)
				if len(left.runes) == 0 || len(right.runes) == 0 {
					// One side was consumed by an earlier merge; stale.
					continue
				}

				// Re-validate: either entry may have grown since this
				// candidate was enqueued.
				if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
					continue
				}

				// Absorb the right entry into the left and relink the list.
				merges[pair.a].runes = append(left.runes, right.runes...)
				merges[pair.b].runes = nil
				merges[pair.a].n = right.n
				if right.n < len(merges) {
					merges[right.n].p = pair.a
				}

				// Enqueue fresh candidates with both new neighbors.
				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
					pq.Enqueue(pair)
				}

				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
					pq.Enqueue(pair)
				}
			}

			slog.Debug("merges", "merges", merges)

			// Emit ids for the surviving (non-empty) merged spans, in order.
			for _, merge := range merges {
				if len(merge.runes) > 0 {
					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
						ids = append(ids, id)
					} else {
						slog.Debug("missing token", "token", string(merge.runes))
					}
				}
			}
		}
	}

	if addSpecial && len(ids) > 0 {
		if spm.vocab.AddBOS {
			if ids[0] == spm.vocab.BOS {
				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
			}

			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
			ids = append([]int32{spm.vocab.BOS}, ids...)
		}

		if spm.vocab.AddEOS {
			if ids[len(ids)-1] == spm.vocab.EOS {
				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
			}

			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
			ids = append(ids, spm.vocab.EOS)
		}
	}

	return ids, nil
}
// candidate is a proposed merge of the linked-list entries at indices a and
// b, ranked by the vocabulary score of the combined token.
type candidate struct {
	a, b  int
	score float32
}
func
(
spm
SentencePieceModel
)
Decode
(
ids
[]
int32
)
(
string
,
error
)
{
var
sb
strings
.
Builder
for
_
,
id
:=
range
ids
{
data
:=
spm
.
vocab
.
Decode
(
id
)
data
=
strings
.
ReplaceAll
(
data
,
spmWhitespaceSep
,
" "
)
if
_
,
err
:=
sb
.
WriteString
(
data
);
err
!=
nil
{
return
""
,
err
}
}
slog
.
Debug
(
"decoded"
,
"ids"
,
ids
,
"text"
,
sb
.
String
())
return
sb
.
String
(),
nil
}
model/process_text_spm_test.go
0 → 100644
View file @
aee28501
package
model
import
(
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func
loadSentencePieceVocab
(
t
*
testing
.
T
)
SentencePieceModel
{
t
.
Helper
()
bts
,
err
:=
os
.
ReadFile
(
filepath
.
Join
(
"testdata"
,
"gemma2"
,
"tokenizer.model"
))
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
var
spm
sentencepiece
.
ModelProto
if
err
:=
proto
.
Unmarshal
(
bts
,
&
spm
);
err
!=
nil
{
t
.
Fatal
(
err
)
}
preTokenizer
:=
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var
v
Vocabulary
for
_
,
piece
:=
range
spm
.
GetPieces
()
{
v
.
Values
=
append
(
v
.
Values
,
piece
.
GetPiece
())
v
.
Scores
=
append
(
v
.
Scores
,
piece
.
GetScore
())
switch
t
:=
piece
.
GetType
();
t
{
case
sentencepiece
.
ModelProto_SentencePiece_UNKNOWN
,
sentencepiece
.
ModelProto_SentencePiece_CONTROL
,
sentencepiece
.
ModelProto_SentencePiece_UNUSED
,
sentencepiece
.
ModelProto_SentencePiece_BYTE
:
v
.
Types
=
append
(
v
.
Types
,
uint32
(
t
))
default
:
tt
:=
uint32
(
sentencepiece
.
ModelProto_SentencePiece_NORMAL
)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v
.
Types
=
append
(
v
.
Types
,
tt
)
}
}
return
NewSentencePieceModel
(
preTokenizer
,
&
v
)
}
// TestSentencePieceEncode checks that Encode/Decode roundtrip a variety of
// inputs and that special tokens map to their expected ids.
func TestSentencePieceEncode(t *testing.T) {
	// Debug-level logging so tokenizer internals are visible on failure.
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		// Each case must survive Encode followed by Decode unchanged,
		// including leading/trailing and repeated whitespace.
		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " +
				"Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		// Expected ids come from the gemma2 testdata tokenizer.model.
		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token, true)
			if err != nil {
				t.Fatal(err)
			}

			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}
model/testdata/gemma2/tokenizer.model
0 → 100644
View file @
aee28501
File added
server/prompt.go
View file @
aee28501
...
...
@@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var
system
[]
api
.
Message
isMllama
:=
checkMllamaModelFamily
(
m
)
isGemma3
:=
checkGemma3ModelFamily
(
m
)
var
imageNumTokens
int
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
...
...
@@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
n
:=
len
(
msgs
)
-
1
// in reverse, find all messages that fit into context window
for
i
:=
n
;
i
>=
0
;
i
--
{
if
isMllama
&&
len
(
msgs
[
i
]
.
Images
)
>
1
{
if
(
isMllama
||
isGemma3
)
&&
len
(
msgs
[
i
]
.
Images
)
>
1
{
return
""
,
nil
,
errTooManyImages
}
...
...
@@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
}
return
false
}
func
checkGemma3ModelFamily
(
m
*
Model
)
bool
{
for
_
,
arch
:=
range
m
.
Config
.
ModelFamilies
{
if
arch
==
"gemma3"
{
return
true
}
}
return
false
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment