Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
76f88caf
Unverified
Commit
76f88caf
authored
Dec 09, 2025
by
nicole pardal
Committed by
GitHub
Dec 09, 2025
Browse files
nomic-embed-text:v2: model implementation (#13162)
parent
2bccf8c6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
304 additions
and
18 deletions
+304
-18
convert/convert.go
convert/convert.go
+2
-0
convert/convert_nomicbert.go
convert/convert_nomicbert.go
+213
-0
model/models/nomicbert/model.go
model/models/nomicbert/model.go
+89
-18
No files found.
convert/convert.go
View file @
76f88caf
...
@@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
...
@@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv
=
&
qwen3VLModel
{}
conv
=
&
qwen3VLModel
{}
case
"BertModel"
:
case
"BertModel"
:
conv
=
&
bertModel
{}
conv
=
&
bertModel
{}
case
"NomicBertModel"
,
"NomicBertMoEModel"
:
conv
=
&
nomicbertModel
{}
case
"CohereForCausalLM"
:
case
"CohereForCausalLM"
:
conv
=
&
commandrModel
{}
conv
=
&
commandrModel
{}
case
"GptOssForCausalLM"
:
case
"GptOssForCausalLM"
:
...
...
convert/convert_nomicbert.go
0 → 100644
View file @
76f88caf
package
convert
import
(
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
// nomicbertModel adapts a Hugging Face NomicBert / NomicBertMoE checkpoint
// for GGUF conversion. JSON-tagged fields mirror keys in the model's
// config.json; untagged fields are filled in by parseMore from the
// sentence-transformers metadata files.
type nomicbertModel struct {
	ModelParameters
	NLayers               uint32  `json:"n_layers"`
	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
	HiddenSize            uint32  `json:"hidden_size"`
	IntermediateSize      uint32  `json:"intermediate_size"`
	NumAttentionHeads     uint32  `json:"num_attention_heads"`
	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
	LayerNormEPS          float32 `json:"layer_norm_eps"`
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	RopeFreqBase          float32 `json:"rope_theta"`

	// Set by parseMore from modules.json / the pooling module's
	// config.json, not from config.json.
	normalizeEmbeddings bool
	PoolingType         uint32

	// MoE parameters (only present in v2 models)
	NumExperts      uint32 `json:"num_local_experts"`
	NumExpertsUsed  uint32 `json:"num_experts_per_tok"`
	MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}

// Compile-time checks that nomicbertModel satisfies the converter interfaces.
var (
	_ ModelConverter = (*nomicbertModel)(nil)
	_ moreParser     = (*nomicbertModel)(nil)
)
func
(
p
*
nomicbertModel
)
parseMore
(
fsys
fs
.
FS
)
error
{
bts
,
err
:=
fs
.
ReadFile
(
fsys
,
"modules.json"
)
if
err
!=
nil
{
return
err
}
var
modules
[]
struct
{
Type
string
`json:"type"`
Path
string
`json:"path"`
}
if
err
:=
json
.
Unmarshal
(
bts
,
&
modules
);
err
!=
nil
{
return
err
}
var
pooling
string
for
_
,
m
:=
range
modules
{
switch
m
.
Type
{
case
"sentence_transformers.models.Pooling"
:
pooling
=
m
.
Path
case
"sentence_transformers.models.Normalize"
:
p
.
normalizeEmbeddings
=
true
}
}
if
pooling
!=
""
{
bts
,
err
:=
fs
.
ReadFile
(
fsys
,
filepath
.
Join
(
pooling
,
"config.json"
))
if
err
!=
nil
{
return
err
}
var
pc
struct
{
PoolingModeCLSToken
bool
`json:"pooling_mode_cls_token"`
PoolingModeMeanTokens
bool
`json:"pooling_mode_mean_tokens"`
}
if
err
:=
json
.
Unmarshal
(
bts
,
&
pc
);
err
!=
nil
{
return
err
}
if
pc
.
PoolingModeMeanTokens
{
p
.
PoolingType
=
1
}
else
if
pc
.
PoolingModeCLSToken
{
p
.
PoolingType
=
2
}
}
return
nil
}
// KV assembles the GGUF metadata key/value map for a nomic-bert (or
// nomic-bert-moe) checkpoint from the parsed config values and rewrites
// the tokenizer vocabulary into phantom-space form.
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)

	// Determine architecture based on MoE parameters (following qwen3 pattern)
	arch := "nomic-bert"
	if p.MoEEveryNLayers > 0 {
		arch += "-moe"
	}
	kv["general.architecture"] = arch

	// Embedding models use bidirectional (non-causal) attention.
	kv["attention.causal"] = false
	kv["pooling_type"] = p.PoolingType
	kv["normalize_embeddings"] = p.normalizeEmbeddings

	// Checkpoints name the layer count either n_layers or
	// num_hidden_layers; take the first non-zero of the two.
	kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)

	// Optional values are only emitted when present (> 0) in the config.
	if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
		kv["context_length"] = contextLength
	}

	if embeddingLength := p.HiddenSize; embeddingLength > 0 {
		kv["embedding_length"] = p.HiddenSize
	}

	if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
		kv["feed_forward_length"] = p.IntermediateSize
	}

	if headCount := p.NumAttentionHeads; headCount > 0 {
		kv["attention.head_count"] = p.NumAttentionHeads
	}

	if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
		kv["attention.head_count_kv"] = p.NumKeyValueHeads
	}

	// layer_norm_eps and layer_norm_epsilon are alternate spellings used
	// by different checkpoints; take the first non-zero value.
	if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
		kv["attention.layer_norm_epsilon"] = layerNormEpsilon
	}

	if p.RopeFreqBase > 0 {
		kv["rope.freq_base"] = p.RopeFreqBase
	}

	// MoE specific parameters (only if MoE is enabled)
	if p.NumExperts > 0 {
		kv["expert_count"] = p.NumExperts
	}
	if p.NumExpertsUsed > 0 {
		kv["expert_used_count"] = p.NumExpertsUsed
	}
	if p.MoEEveryNLayers > 0 {
		kv["moe_every_n_layers"] = p.MoEEveryNLayers
	}

	kv["tokenizer.ggml.model"] = "bert"
	kv["tokenizer.ggml.token_type_count"] = uint32(2)

	// convert to phantom space tokens
	for i, e := range t.Tokens {
		switch {
		case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
			// noop - keep special tokens as-is
		case strings.HasPrefix(e, "##"):
			// WordPiece continuation token: strip the "##" marker.
			t.Tokens[i] = e[2:]
		default:
			// Word-initial token: prefix with U+2581, the
			// sentencepiece-style phantom space.
			t.Tokens[i] = "\u2581" + e
		}
	}
	kv["tokenizer.ggml.tokens"] = t.Tokens

	return kv
}
func
(
p
*
nomicbertModel
)
Tensors
(
ts
[]
Tensor
)
[]
*
ggml
.
Tensor
{
out
:=
make
([]
*
ggml
.
Tensor
,
0
,
len
(
ts
))
for
_
,
t
:=
range
ts
{
if
slices
.
Contains
([]
string
{
"embeddings.position_ids"
,
"pooler.dense.weight"
,
"pooler.dense.bias"
,
},
t
.
Name
())
{
continue
}
out
=
append
(
out
,
&
ggml
.
Tensor
{
Name
:
t
.
Name
(),
Kind
:
t
.
Kind
(),
Shape
:
t
.
Shape
(),
WriterTo
:
t
,
})
}
return
out
}
// Replacements returns ordered old/new name pairs used to map checkpoint
// tensor names onto GGUF tensor names. Pairs are applied in order, so the
// more specific attention/MLP entries appear before the generic
// intermediate/output fallbacks.
func (nomicbertModel) Replacements() []string {
	return []string{
		// Layer containers.
		"encoder.layer", "blk",
		"encoder.layers", "blk",
		// Embeddings.
		"embeddings.word_embeddings", "token_embd",
		"embeddings.token_type_embeddings", "token_types",
		"embeddings.LayerNorm", "token_embd_norm",
		// Attention.
		"attention.self.qkv", "attn_qkv",
		"attention.output.dense", "attn_output",
		"attention.output.LayerNorm", "attn_output_norm",
		// MLP (v2 naming).
		"mlp.up", "ffn_up",
		"mlp.down", "ffn_down",
		// MoE router and experts.
		"mlp.router", "ffn_gate_inp",
		"mlp.experts.up", "ffn_up_exps",
		"mlp.experts.down", "ffn_down_exps",
		// Generic BERT-style fallbacks.
		"intermediate.dense", "ffn_up",
		"output.dense", "ffn_down",
		"output.LayerNorm", "layer_output_norm",
	}
}
model/models/nomicbert/model.go
View file @
76f88caf
...
@@ -34,19 +34,23 @@ type Options struct {
...
@@ -34,19 +34,23 @@ type Options struct {
poolingType
pooling
.
Type
poolingType
pooling
.
Type
normalize
bool
normalize
bool
ropeFreqBase
float32
ropeFreqBase
float32
// MoE specific options (used by v2 / MoE models only)
numExperts
int
numExpertsUsed
int
moeEveryNLayers
int
}
}
func
(
o
Options
)
applyRotaryPositionEmbeddings
(
ctx
ml
.
Context
,
states
,
positions
ml
.
Tensor
)
ml
.
Tensor
{
func
(
o
Options
)
applyRotaryPositionEmbeddings
(
ctx
ml
.
Context
,
states
,
positions
ml
.
Tensor
)
ml
.
Tensor
{
return
nn
.
RoPE
(
ctx
,
states
,
positions
,
o
.
headDim
,
o
.
ropeFreqBase
,
1.0
,
rope
.
WithTypeNeoX
())
return
nn
.
RoPE
(
ctx
,
states
,
positions
,
o
.
headDim
,
o
.
ropeFreqBase
,
1.0
,
rope
.
WithTypeNeoX
())
}
}
// Single Encoder Layer
type
EncoderLayer
struct
{
type
EncoderLayer
struct
{
*
Attention
*
Attention
AttentionNorm
*
nn
.
LayerNorm
`gguf:"attn_output_norm"`
AttentionNorm
*
nn
.
LayerNorm
`gguf:"attn_output_norm"`
*
MLP
FeedForward
FeedForward
MLPNorm
*
nn
.
LayerNorm
`gguf:"layer_output_norm"`
MLPNorm
*
nn
.
LayerNorm
`gguf:"layer_output_norm"`
}
}
...
@@ -56,12 +60,63 @@ type Attention struct {
...
@@ -56,12 +60,63 @@ type Attention struct {
Output
*
nn
.
Linear
`gguf:"attn_output"`
Output
*
nn
.
Linear
`gguf:"attn_output"`
}
}
// FeedForward is the per-layer MLP contract. It is implemented by the
// dense, denseGELU, and sparse (MoE) variants below, chosen per layer
// when the model is constructed.
type FeedForward interface {
	Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}

// dense is the gated MLP used by non-MoE layers.
type dense struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
func
(
mlp
*
dense
)
Forward
(
ctx
ml
.
Context
,
hiddenStates
ml
.
Tensor
,
_
*
Options
)
ml
.
Tensor
{
hidden
:=
mlp
.
Gate
.
Forward
(
ctx
,
hiddenStates
)
.
SILU
(
ctx
,
mlp
.
Up
.
Forward
(
ctx
,
hiddenStates
))
return
mlp
.
Down
.
Forward
(
ctx
,
hidden
)
}
// denseGELU implements MLP with GELU activation for v2 MoE dense layers
type
denseGELU
struct
{
Up
*
nn
.
Linear
`gguf:"ffn_up"`
Down
*
nn
.
Linear
`gguf:"ffn_down"`
}
func
(
mlp
*
denseGELU
)
Forward
(
ctx
ml
.
Context
,
hiddenStates
ml
.
Tensor
,
_
*
Options
)
ml
.
Tensor
{
return
mlp
.
Down
.
Forward
(
ctx
,
mlp
.
Up
.
Forward
(
ctx
,
hiddenStates
)
.
GELU
(
ctx
))
}
// sparse implements MoE with expert routing
type sparse struct {
	// Router produces per-expert logits for each token.
	Router *nn.Linear `gguf:"ffn_gate_inp"`
	// Up/Down hold the stacked per-expert projection weights.
	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
// Forward routes each token to the top-k experts, runs the expert MLPs
// (GELU activation), and returns the routing-weighted sum of the expert
// outputs.
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
	// Flatten (dim, seq, batch) to (dim, seq*batch) so routing treats
	// every token independently.
	hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)

	// Softmax over expert logits, then pick the top-k experts per token.
	routerLogits := moe.Router.Forward(ctx, hiddenStates)
	routingWeights := routerLogits.Softmax(ctx)
	selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)

	// Gather the weights of only the selected experts.
	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)

	// Add a middle axis so the batched expert projections broadcast per
	// selected expert.
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

	hiddenStates = moe.Up.Forward(ctx, hiddenStates, selectedExperts).GELU(ctx)

	experts := moe.Down.Forward(ctx, hiddenStates, selectedExperts)
	// Scale each expert's output by its routing weight.
	experts = experts.Mul(ctx, routingWeights)

	// Sum the k expert outputs: start from expert 0's view, then add the
	// remaining experts via strided views into the same tensor.
	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}

	return nextStates
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
...
@@ -92,7 +147,7 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions
...
@@ -92,7 +147,7 @@ func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions
hiddenStates
=
e
.
AttentionNorm
.
Forward
(
ctx
,
hiddenStates
,
opts
.
eps
)
hiddenStates
=
e
.
AttentionNorm
.
Forward
(
ctx
,
hiddenStates
,
opts
.
eps
)
residual
=
hiddenStates
residual
=
hiddenStates
hiddenStates
=
e
.
MLP
.
Forward
(
ctx
,
hiddenStates
)
hiddenStates
=
e
.
FeedForward
.
Forward
(
ctx
,
hiddenStates
,
opts
)
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
residual
)
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
residual
)
hiddenStates
=
e
.
MLPNorm
.
Forward
(
ctx
,
hiddenStates
,
opts
.
eps
)
hiddenStates
=
e
.
MLPNorm
.
Forward
(
ctx
,
hiddenStates
,
opts
.
eps
)
...
@@ -118,12 +173,6 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
...
@@ -118,12 +173,6 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
return
a
.
Output
.
Forward
(
ctx
,
attention
)
return
a
.
Output
.
Forward
(
ctx
,
attention
)
}
}
func
(
m
*
MLP
)
Forward
(
ctx
ml
.
Context
,
hiddenStates
ml
.
Tensor
)
ml
.
Tensor
{
hidden
:=
m
.
Gate
.
Forward
(
ctx
,
hiddenStates
)
.
SILU
(
ctx
,
m
.
Up
.
Forward
(
ctx
,
hiddenStates
))
return
m
.
Down
.
Forward
(
ctx
,
hidden
)
}
func
New
(
c
fs
.
Config
)
(
model
.
Model
,
error
)
{
func
New
(
c
fs
.
Config
)
(
model
.
Model
,
error
)
{
hiddenSize
:=
int
(
c
.
Uint
(
"embedding_length"
))
hiddenSize
:=
int
(
c
.
Uint
(
"embedding_length"
))
numHeads
:=
int
(
c
.
Uint
(
"attention.head_count"
))
numHeads
:=
int
(
c
.
Uint
(
"attention.head_count"
))
...
@@ -152,17 +201,37 @@ func New(c fs.Config) (model.Model, error) {
...
@@ -152,17 +201,37 @@ func New(c fs.Config) (model.Model, error) {
false
,
false
,
)
)
blockCount
:=
int
(
c
.
Uint
(
"block_count"
))
moeEveryNLayers
:=
int
(
c
.
Uint
(
"moe_every_n_layers"
,
0
))
layers
:=
make
([]
EncoderLayer
,
blockCount
)
for
i
:=
range
layers
{
if
moeEveryNLayers
>
0
{
// Layer uses MoE if (i+1) % moe_every_n_layers == 0
if
(
i
+
1
)
%
moeEveryNLayers
==
0
{
layers
[
i
]
.
FeedForward
=
&
sparse
{}
}
else
{
layers
[
i
]
.
FeedForward
=
&
denseGELU
{}
}
}
else
{
layers
[
i
]
.
FeedForward
=
&
dense
{}
}
}
return
&
Model
{
return
&
Model
{
TextProcessor
:
processor
,
TextProcessor
:
processor
,
Layers
:
make
([]
EncoderLayer
,
c
.
Uint
(
"block_count"
))
,
Layers
:
layers
,
Options
:
Options
{
Options
:
Options
{
hiddenSize
:
hiddenSize
,
hiddenSize
:
hiddenSize
,
numHeads
:
numHeads
,
numHeads
:
numHeads
,
headDim
:
headDim
,
headDim
:
headDim
,
eps
:
c
.
Float
(
"attention.layer_norm_epsilon"
),
eps
:
c
.
Float
(
"attention.layer_norm_epsilon"
),
poolingType
:
pooling
.
Type
(
c
.
Uint
(
"pooling_type"
)),
poolingType
:
pooling
.
Type
(
c
.
Uint
(
"pooling_type"
)),
normalize
:
c
.
Bool
(
"normalize_embeddings"
,
false
),
normalize
:
c
.
Bool
(
"normalize_embeddings"
,
false
),
ropeFreqBase
:
c
.
Float
(
"rope.freq_base"
,
1000.0
),
ropeFreqBase
:
c
.
Float
(
"rope.freq_base"
,
1000.0
),
numExperts
:
int
(
c
.
Uint
(
"expert_count"
)),
numExpertsUsed
:
int
(
c
.
Uint
(
"expert_used_count"
)),
moeEveryNLayers
:
moeEveryNLayers
,
},
},
},
nil
},
nil
}
}
...
@@ -170,4 +239,6 @@ func New(c fs.Config) (model.Model, error) {
...
@@ -170,4 +239,6 @@ func New(c fs.Config) (model.Model, error) {
// init registers the nomic-bert architecture names (including the MoE
// variants) so the runtime constructs them via New.
func init() {
	model.Register("nomic-bert", New)
	model.Register("nomic-bert_embed", New)
	model.Register("nomic-bert-moe", New)
	model.Register("nomic-bert-moe_embed", New)
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment