Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
c253433d
Unverified
Commit
c253433d
authored
Sep 16, 2025
by
Michael Yang
Committed by
GitHub
Sep 16, 2025
Browse files
embed: cleanup (#12299)
* cleanup * use pooling.TypeNone * pooling test
parent
a1cff89b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
19 deletions
+104
-19
ml/nn/pooling/pooling.go
ml/nn/pooling/pooling.go
+18
-12
ml/nn/pooling/pooling_test.go
ml/nn/pooling/pooling_test.go
+79
-0
model/model.go
model/model.go
+2
-2
model/models/bert/embed.go
model/models/bert/embed.go
+1
-1
model/models/gemma3/embed.go
model/models/gemma3/embed.go
+1
-1
runner/ollamarunner/runner.go
runner/ollamarunner/runner.go
+3
-3
No files found.
ml/nn/pooling/pooling.go
View file @
c253433d
...
@@ -11,26 +11,32 @@ const (
...
@@ -11,26 +11,32 @@ const (
TypeMean
TypeMean
TypeCLS
TypeCLS
TypeLast
TypeLast
TypeRank
TypeUnknown
=
0xFFFFFFFE
TypeUnspecified
=
0xFFFFFFFF
)
)
func
Pooling
(
ctx
ml
.
Context
,
hiddenStates
ml
.
Tensor
,
poolingType
Type
)
ml
.
Tensor
{
func
(
t
Type
)
String
()
string
{
switch
poolingType
{
switch
t
{
case
TypeNone
:
case
TypeMean
:
return
hiddenStates
return
"Mean"
case
TypeCLS
:
return
"CLS"
case
TypeLast
:
return
"Last"
default
:
return
"Unknown"
}
}
func
(
t
Type
)
Forward
(
ctx
ml
.
Context
,
hiddenStates
ml
.
Tensor
)
ml
.
Tensor
{
switch
t
{
case
TypeMean
:
case
TypeMean
:
hiddenStates
=
hiddenStates
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
.
Mean
(
ctx
)
hiddenStates
=
hiddenStates
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
.
Mean
(
ctx
)
return
hiddenStates
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
return
hiddenStates
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
case
TypeCLS
:
case
TypeCLS
:
return
hiddenStates
.
View
(
ctx
,
0
,
hiddenStates
.
Dim
(
0
))
return
hiddenStates
.
View
(
ctx
,
0
,
hiddenStates
.
Dim
(
0
))
case
TypeLast
:
case
TypeLast
:
panic
(
"not implemented"
)
hiddenStates
=
hiddenStates
.
View
(
ctx
,
(
hiddenStates
.
Dim
(
1
)
-
1
)
*
hiddenStates
.
Stride
(
1
),
hiddenStates
.
Dim
(
0
))
case
TypeRank
:
return
hiddenStates
panic
(
"not implemented"
)
default
:
default
:
panic
(
"
not implemented
"
)
panic
(
"
unknown pooling type
"
)
}
}
}
}
ml/nn/pooling/pooling_test.go
0 → 100644
View file @
c253433d
package
pooling_test
import
(
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/discover"
fsggml
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn/pooling"
)
func
setup
(
tb
testing
.
TB
,
n
int
)
ml
.
Backend
{
tb
.
Helper
()
f
,
err
:=
os
.
CreateTemp
(
tb
.
TempDir
(),
"*.bin"
)
if
err
!=
nil
{
tb
.
Fatal
(
err
)
}
defer
f
.
Close
()
if
err
:=
fsggml
.
WriteGGUF
(
f
,
fsggml
.
KV
{
"general.architecture"
:
"test"
,
"test.block_count"
:
uint32
(
1
),
},
[]
*
fsggml
.
Tensor
{
{
Name
:
"blk.0.weight"
,
Shape
:
[]
uint64
{
1
},
WriterTo
:
bytes
.
NewBuffer
(
make
([]
byte
,
4
))},
});
err
!=
nil
{
tb
.
Fatal
(
err
)
}
var
gpuLayers
ml
.
GPULayersList
if
gpus
:=
discover
.
GetGPUInfo
();
len
(
gpus
)
>
0
{
gpuLayers
=
append
(
gpuLayers
,
ml
.
GPULayers
{
ID
:
gpus
[
0
]
.
ID
,
Layers
:
slices
.
Collect
(
func
(
yield
func
(
int
)
bool
)
{
for
i
:=
range
n
{
if
!
yield
(
i
)
{
return
}
}
}),
})
}
b
,
err
:=
ggml
.
New
(
f
.
Name
(),
ml
.
BackendParams
{
AllocMemory
:
true
,
GPULayers
:
gpuLayers
})
if
err
!=
nil
{
tb
.
Fatal
(
err
)
}
return
b
}
func
TestForward
(
t
*
testing
.
T
)
{
cases
:=
map
[
pooling
.
Type
][]
float32
{
pooling
.
TypeMean
:
{
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
},
pooling
.
TypeCLS
:
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
},
pooling
.
TypeLast
:
{
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
},
}
for
typ
,
want
:=
range
cases
{
t
.
Run
(
typ
.
String
(),
func
(
t
*
testing
.
T
)
{
b
:=
setup
(
t
,
99
)
defer
b
.
Close
()
ctx
:=
b
.
NewContext
()
defer
ctx
.
Close
()
tt
:=
ctx
.
Input
()
.
Arange
(
0
,
16
,
1
,
ml
.
DTypeF32
)
.
Reshape
(
ctx
,
8
,
2
)
tt
=
typ
.
Forward
(
ctx
,
tt
)
ctx
.
Forward
(
tt
)
.
Compute
(
tt
)
if
diff
:=
cmp
.
Diff
(
want
,
tt
.
Floats
());
diff
!=
""
{
t
.
Error
(
diff
)
}
})
}
}
model/model.go
View file @
c253433d
...
@@ -5,7 +5,6 @@ import (
...
@@ -5,7 +5,6 @@ import (
"fmt"
"fmt"
_
"image/jpeg"
_
"image/jpeg"
_
"image/png"
_
"image/png"
"math"
"os"
"os"
"reflect"
"reflect"
"strconv"
"strconv"
...
@@ -21,6 +20,7 @@ import (
...
@@ -21,6 +20,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml"
_
"github.com/ollama/ollama/ml/backend"
_
"github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/input"
)
)
...
@@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
...
@@ -108,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
}
}
arch
:=
b
.
Config
()
.
Architecture
()
arch
:=
b
.
Config
()
.
Architecture
()
if
b
.
Config
()
.
Uint
(
"pooling_type"
,
math
.
MaxUint32
)
!=
math
.
MaxUint32
{
if
pooling
.
Type
(
b
.
Config
()
.
Uint
(
"pooling_type"
))
!=
pooling
.
TypeNone
{
arch
=
arch
+
"_embed"
arch
=
arch
+
"_embed"
}
}
...
...
model/models/bert/
model
.go
→
model/models/bert/
embed
.go
View file @
c253433d
...
@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
...
@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates
=
layer
.
Forward
(
ctx
,
hiddenStates
,
&
m
.
Options
)
hiddenStates
=
layer
.
Forward
(
ctx
,
hiddenStates
,
&
m
.
Options
)
}
}
hiddenStates
=
pooling
.
Pooling
(
ctx
,
hiddenStates
,
m
.
poolingType
)
hiddenStates
=
m
.
pooling
Type
.
Forward
(
ctx
,
hiddenStates
)
if
m
.
normalize
{
if
m
.
normalize
{
hiddenStates
=
hiddenStates
.
L2Norm
(
ctx
,
1e-12
)
hiddenStates
=
hiddenStates
.
L2Norm
(
ctx
,
1e-12
)
}
}
...
...
model/models/gemma3/embed.go
View file @
c253433d
...
@@ -22,7 +22,7 @@ type embedModel struct {
...
@@ -22,7 +22,7 @@ type embedModel struct {
func
(
m
*
embedModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
embedModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
hiddenStates
:=
m
.
TextModel
.
Forward
(
ctx
,
batch
,
m
.
Cache
)
hiddenStates
:=
m
.
TextModel
.
Forward
(
ctx
,
batch
,
m
.
Cache
)
hiddenStates
=
pooling
.
Pooling
(
ctx
,
hiddenStates
,
m
.
poolingType
)
hiddenStates
=
m
.
pooling
Type
.
Forward
(
ctx
,
hiddenStates
)
for
_
,
dense
:=
range
m
.
Dense
{
for
_
,
dense
:=
range
m
.
Dense
{
hiddenStates
=
dense
.
Forward
(
ctx
,
hiddenStates
)
hiddenStates
=
dense
.
Forward
(
ctx
,
hiddenStates
)
}
}
...
...
runner/ollamarunner/runner.go
View file @
c253433d
...
@@ -11,7 +11,6 @@ import (
...
@@ -11,7 +11,6 @@ import (
"image"
"image"
"log"
"log"
"log/slog"
"log/slog"
"math"
"net"
"net"
"net/http"
"net/http"
"os"
"os"
...
@@ -32,6 +31,7 @@ import (
...
@@ -32,6 +31,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/runner/common"
...
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
...
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func
(
s
*
Server
)
run
(
ctx
context
.
Context
)
{
func
(
s
*
Server
)
run
(
ctx
context
.
Context
)
{
s
.
ready
.
Wait
()
s
.
ready
.
Wait
()
supportsAsync
:=
s
.
model
.
Backend
()
.
Config
()
.
Uint
(
"pooling_type"
,
math
.
MaxUint32
)
==
math
.
MaxUint32
supportsAsync
:=
pooling
.
Type
(
s
.
model
.
Backend
()
.
Config
()
.
Uint
(
"pooling_type"
))
==
pooling
.
TypeNone
var
activeBatch
batchState
var
activeBatch
batchState
for
{
for
{
...
@@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
...
@@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
func
(
s
*
Server
)
embeddings
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
func
(
s
*
Server
)
embeddings
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
if
s
.
model
.
Backend
()
.
Config
()
.
Uint
(
"pooling_type"
,
math
.
MaxUint32
)
==
math
.
MaxUint32
{
if
pooling
.
Type
(
s
.
model
.
Backend
()
.
Config
()
.
Uint
(
"pooling_type"
))
==
pooling
.
TypeNone
{
http
.
Error
(
w
,
"this model does not support embeddings"
,
http
.
StatusNotImplemented
)
http
.
Error
(
w
,
"this model does not support embeddings"
,
http
.
StatusNotImplemented
)
return
return
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment