OpenDAS / ollama · Commits

Commit 9ed8bf14 (unverified)
ml: add more rope options (#10775)
Authored by Michael Yang on May 20, 2025; committed via GitHub on May 20, 2025
Parent: e6a800ca

Showing 11 changed files with 116 additions and 83 deletions.
ml/backend.go                         +0   -16
ml/backend/ggml/ggml.go               +7   -19
ml/nn/fast/rope.go                    +21  -0
ml/nn/rope/rope.go                    +33  -0
model/models/gemma2/model.go          +5   -4
model/models/gemma3/model_text.go     +5   -4
model/models/llama/model.go           +9   -8
model/models/llama4/model_text.go     +5   -3
model/models/mistral3/model_text.go   +8   -8
model/models/mllama/model_text.go     +7   -6
model/models/qwen25vl/model_text.go   +16  -15
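In short, the commit replaces the old positional RoPE arguments on ml.Tensor (a raw ropeType uint32, a positional ropeFactors tensor, and the ml.RopeOption helpers) with a small functional-options package, ml/nn/rope, plus a free helper, fast.RoPE, in the new ml/nn/fast package. The sketch below shows the resulting call-site shape; it assumes the ollama module at this commit is on the import path, and the tensor names and numeric values are placeholders rather than values taken from any real model file.

// Call-site sketch only, not part of the commit; placeholder values throughout.
package sketch

import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn/fast"
    "github.com/ollama/ollama/ml/nn/rope"
)

func applyRoPE(ctx ml.Context, q, positions, factors ml.Tensor) ml.Tensor {
    // Before: q.RoPE(ctx, positions, factors, ropeDim, ropeType, base, scale),
    // with ropeDim and ropeType passed as raw uint32 values.
    //
    // After: dimension, base, and scale stay positional; everything else is an
    // option constructed by the ml/nn/rope package.
    return fast.RoPE(ctx, q, positions, 128, 10000, 1,
        rope.WithTypeNeoX(),                    // NeoX-style rotation (the old ropeType 2)
        rope.WithFactors(factors),              // optional long-context frequency factors
        rope.WithOriginalContextLength(131072), // original training context; 131072 is also the backend default
    )
}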
ml/backend.go

@@ -115,21 +115,6 @@ type Context interface {
     Layer(int) Context
 }
 
-// RopeOptions contains optional parameters for RoPE function
-type RopeOptions struct {
-    OriginalContextLen uint32
-}
-
-// RopeOption defines a function that modifies RopeOpts
-type RopeOption func(*RopeOptions)
-
-// WithContextLen sets a custom context length
-func WithContextLen(len uint32) RopeOption {
-    return func(opts *RopeOptions) {
-        opts.OriginalContextLen = len
-    }
-}
-
 type Tensor interface {
     Dim(n int) int
     Stride(n int) int

@@ -155,7 +140,6 @@ type Tensor interface {
     AvgPool2D(ctx Context, k, s int, p float32) Tensor
     Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-    RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
     IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
     Sin(ctx Context) Tensor
ml/backend/ggml/ggml.go

@@ -30,6 +30,7 @@ import (
     "github.com/ollama/ollama/logutil"
     "github.com/ollama/ollama/ml"
     ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+    "github.com/ollama/ollama/ml/nn/rope"
     "golang.org/x/sync/errgroup"
 )

@@ -1074,28 +1075,15 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
     }
 }
 
-const (
-    ropeTypeNorm   C.int = 0
-    ropeTypeNeox   C.int = 2
-    ropeTypeMrope  C.int = 8
-    ropeTypeVision C.int = 24
-)
-
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
+func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
     // Default options
-    opts := &ml.RopeOptions{
-        OriginalContextLen: 131072,
-    }
+    opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}
 
     // Apply any provided options
     for _, option := range options {
         option(opts)
     }
 
-    if ropeFactors == nil {
-        ropeFactors = &Tensor{b: t.b}
-    }
-
     dequant := t.t
     if C.ggml_is_quantized(t.t._type) {
         dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)

@@ -1106,11 +1094,11 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
         t: C.ggml_rope_ext(
             ctx.(*Context).ctx,
             dequant,
-            positionIDs.(*Tensor).t,
-            ropeFactors.(*Tensor).t,
+            positions.(*Tensor).t,
+            opts.Factors.(*Tensor).t,
             C.int(ropeDim),
-            C.int(ropeType),
-            C.int(opts.OriginalContextLen),
+            C.int(opts.Type),
+            C.int(opts.OriginalContextLength),
             C.float(ropeBase),
             C.float(ropeScale),
             C.float(0.0),
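For reference, the deleted const block held the ggml rope-type values that now travel through rope.Options.Type into C.int(opts.Type): norm is the zero value, and only the NeoX value gets a constructor (rope.WithTypeNeoX) in this commit. A standalone sketch, not part of the commit, that just tabulates those values:

package main

import "fmt"

func main() {
    // Values from the removed ropeType* constants in ml/backend/ggml/ggml.go.
    ropeTypes := []struct {
        name  string
        value int
    }{
        {"norm", 0},    // default: rope.Options{}.Type is already 0
        {"neox", 2},    // what rope.WithTypeNeoX() sets
        {"mrope", 8},   // no constructor in this commit
        {"vision", 24}, // no constructor in this commit
    }
    for _, t := range ropeTypes {
        fmt.Printf("%-6s -> C.int(opts.Type) = %d\n", t.name, t.value)
    }
}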
ml/nn/fast/rope.go (new file, 0 → 100644)

// fast provides implementations of fast (fused) operations for increased performance.
package fast

import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn/rope"
)

// fastRoPE is an interface for tensors that support fast rotary positional embedding.
type fastRoPE interface {
    RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
}

// RoPE applies rotary positional embedding to tensor `t`.
func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor {
    if t, ok := t.(fastRoPE); ok {
        return t.RoPE(ctx, positions, dim, base, scale, options...)
    }

    panic("RoPE not implemented for this tensor type")
}
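fast.RoPE does not depend on any concrete backend: it discovers the capability at runtime by asserting the unexported fastRoPE interface against whatever tensor it is given, and panics otherwise. A minimal standalone sketch of that dispatch pattern, using stand-in types rather than ollama's real ml.Tensor and ggml backend:

package main

import "fmt"

// Tensor is a stand-in for ml.Tensor.
type Tensor interface{ Name() string }

// roper mirrors the unexported fastRoPE interface: any tensor type that
// implements RoPE with the expected signature is picked up automatically.
type roper interface {
    RoPE(dim int, base, scale float32) Tensor
}

// ggmlTensor is a stand-in for a backend tensor that supports RoPE.
type ggmlTensor struct{ name string }

func (t ggmlTensor) Name() string { return t.name }
func (t ggmlTensor) RoPE(dim int, base, scale float32) Tensor {
    return ggmlTensor{name: fmt.Sprintf("rope(%s,dim=%d)", t.name, dim)}
}

// RoPE is the free function callers use; it never names a concrete backend.
func RoPE(t Tensor, dim int, base, scale float32) Tensor {
    if t, ok := t.(roper); ok {
        return t.RoPE(dim, base, scale)
    }
    panic("RoPE not implemented for this tensor type")
}

func main() {
    fmt.Println(RoPE(ggmlTensor{name: "q"}, 128, 10000, 1).Name())
}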
ml/nn/rope/rope.go (new file, 0 → 100644)

package rope

import "github.com/ollama/ollama/ml"

// Options contains optional parameters for RoPE function
type Options struct {
    OriginalContextLength int
    Type                  int
    Factors               ml.Tensor
}

// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
    return func(opts *Options) {
        opts.OriginalContextLength = n
    }
}

// WithTypeNeoX sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
    return func(opts *Options) {
        opts.Type = 2
    }
}

// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
    return func(opts *Options) {
        if factors != nil {
            opts.Factors = factors
        }
    }
}
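For reference, a short usage sketch of the new options package, assuming the ollama module at this commit is importable; it mirrors the defaults-then-apply loop that the ggml backend's RoPE method runs:

package main

import (
    "fmt"

    "github.com/ollama/ollama/ml/nn/rope"
)

func main() {
    // Same shape as the backend: start from defaults, then apply the caller's
    // options in order.
    opts := &rope.Options{OriginalContextLength: 131072}
    for _, option := range []func(*rope.Options){
        rope.WithTypeNeoX(),                    // Type: 0 -> 2
        rope.WithOriginalContextLength(128000), // e.g. qwen25vl's context_length default
    } {
        option(opts)
    }
    fmt.Println(opts.Type, opts.OriginalContextLength) // 2 128000
}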
model/models/gemma2/model.go

@@ -7,6 +7,8 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"
 )

@@ -83,11 +85,10 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
-    ropeType := uint32(2)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+    q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+    k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -127,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+    return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type MLP struct {
model/models/gemma3/model_text.go

@@ -7,6 +7,8 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model/input"
 )

@@ -73,7 +75,6 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
     batchSize := hiddenState.Dim(1)
-    ropeType := uint32(2)
     ropeBase := opts.ropeLocalBase
     if (layer+1)%gemmaGlobalCacheCount == 0 {

@@ -83,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
     q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+    q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
     if opts.largeModelScaling {
         q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -94,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
     k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+    k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

@@ -112,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
         ropeBase = m.TextConfig.ropeGlobalBase
     }
 
-    return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
+    return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type TextMLP struct {
model/models/llama/model.go

@@ -8,14 +8,16 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
-    hiddenSize, numHeads, numKVHeads, headDim int
-    eps, ropeBase, ropeScale                  float32
-    ropeDim                                   uint32
+    hiddenSize, numHeads, numKVHeads int
+    headDim, ropeDim                 int
+    eps, ropeBase, ropeScale         float32
 }
 
 type Model struct {

@@ -53,10 +55,10 @@ func New(c fs.Config) (model.Model, error) {
             numHeads:   int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
             headDim:    int(c.Uint("attention.key_length")),
+            ropeDim:    int(c.Uint("rope.dimension_count")),
             eps:        c.Float("attention.layer_norm_rms_epsilon"),
             ropeBase:   c.Float("rope.freq_base"),
             ropeScale:  c.Float("rope.freq_scale", 1),
-            ropeDim:    c.Uint("rope.dimension_count"),
         },
     }

@@ -76,15 +78,14 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
-    ropeType := uint32(0)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -97,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+    return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
 }
 
 type MLP struct {
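One detail worth noting in these model call sites: the old tensor method took ropeFactors positionally and the ggml backend special-cased nil, whereas rope.WithFactors simply ignores a nil tensor, so llama-style models can pass sa.RopeFactors unconditionally whether or not the checkpoint ships long-context factors. A small sketch, again assuming the ollama module at this commit is importable:

package main

import (
    "fmt"

    "github.com/ollama/ollama/ml/nn/rope"
)

func main() {
    opts := &rope.Options{}
    rope.WithFactors(nil)(opts) // nil is ignored; the backend's default factors stay in place
    fmt.Println(opts.Factors == nil)
    // Output: true
}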
model/models/llama4/model_text.go

@@ -8,6 +8,8 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model/input"
 )

@@ -31,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
     if useRope {
-        query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
-        key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
+        query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+        key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
     }
 
     if opts.useQKNorm {

@@ -250,5 +252,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
+    return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
 }
model/models/mistral3/model_text.go

@@ -8,13 +8,14 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
     "github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-    hiddenSize, numHeads, numKVHeads, headDim int
-    eps, ropeBase, ropeScale                  float32
-    ropeDim                                   uint32
+    hiddenSize, numHeads, numKVHeads int
+    headDim, ropeDim                 int
+    eps, ropeBase, ropeScale         float32
 }
 
 type TextModel struct {

@@ -35,16 +36,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
-    ropeType := uint32(0)
     headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
 
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+    return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {

@@ -129,10 +129,10 @@ func newTextModel(c fs.Config) *TextModel {
             numHeads:   int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
             headDim:    int(c.Uint("attention.key_length")),
+            ropeDim:    int(c.Uint("rope.dimension_count")),
             eps:        c.Float("attention.layer_norm_rms_epsilon"),
             ropeBase:   c.Float("rope.freq_base"),
             ropeScale:  c.Float("rope.freq_scale", 1),
-            ropeDim:    c.Uint("rope.dimension_count"),
         },
     }
 }
model/models/mllama/model_text.go

@@ -8,6 +8,8 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
 )
 
 type TextSelfAttention struct {

@@ -21,15 +23,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
     batchSize := hiddenState.Dim(1)
     headDim := opts.hiddenSize / opts.numHeads
-    ropeType := uint32(0)
 
     query := sa.Query.Forward(ctx, hiddenState)
     query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
     key := sa.Key.Forward(ctx, hiddenState)
     key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+    key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
 
     value := sa.Value.Forward(ctx, hiddenState)
     value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -44,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
     // This will only get called for layers in the cache, which are just the self attention layers
     if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-        return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+        return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
     }
 
     return key, nil

@@ -199,8 +200,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 type TextModelOptions struct {
     hiddenSize, numHeads, numKVHeads int
+    ropeDim                          int
     eps, ropeBase, ropeScale         float32
-    ropeDim                          uint32
 
     crossAttentionLayers []int32
 }

@@ -240,10 +241,10 @@ func newTextModel(c fs.Config) *TextModel {
             hiddenSize: int(c.Uint("embedding_length")),
             numHeads:   int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
+            ropeDim:    int(c.Uint("rope.dimension_count")),
             eps:        c.Float("attention.layer_norm_rms_epsilon"),
             ropeBase:   c.Float("rope.freq_base"),
             ropeScale:  c.Float("rope.freq_scale", 1),
-            ropeDim:    c.Uint("rope.dimension_count"),
             crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
         },
     }
model/models/qwen25vl/model_text.go

@@ -7,13 +7,15 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/ml/nn/fast"
+    "github.com/ollama/ollama/ml/nn/rope"
     "github.com/ollama/ollama/model/input"
 )
 
 type TextOptions struct {
-    ctxLen, hiddenSize, numHeads, numKVHeads int
-    eps, ropeBase, ropeScale                 float32
-    ropeDim, defaultContextLen               uint32
+    hiddenSize, numHeads, numKVHeads int
+    ropeDim, originalContextLength   int
+    eps, ropeBase, ropeScale         float32
 }
 
 type TextModel struct {

@@ -29,15 +31,14 @@ func NewTextModel(c fs.Config) *TextModel {
     m := TextModel{
         Layers: make([]Layer, c.Uint("block_count")),
         TextOptions: &TextOptions{
-            ctxLen:     int(c.Uint("context_length")),
             hiddenSize: int(c.Uint("embedding_length")),
             numHeads:   int(c.Uint("attention.head_count")),
             numKVHeads: int(c.Uint("attention.head_count_kv")),
-            eps:               c.Float("attention.layer_norm_rms_epsilon"),
-            ropeBase:          c.Float("rope.freq_base"),
-            ropeScale:         c.Float("rope.freq_scale", 1),
-            ropeDim:           c.Uint("rope.dimension_count", 128),
-            defaultContextLen: c.Uint("context_length", 128000),
+            ropeDim:               int(c.Uint("rope.dimension_count", 128)),
+            originalContextLength: int(c.Uint("context_length", 128000)),
+            eps:                   c.Float("attention.layer_norm_rms_epsilon"),
+            ropeBase:              c.Float("rope.freq_base"),
+            ropeScale:             c.Float("rope.freq_scale", 1),
         },
     }

@@ -59,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
     q := sa.Query.Forward(ctx, hiddenState)
     q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-    q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+    q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
     k := sa.Key.Forward(ctx, hiddenState)
     k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-    k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
+    k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
 
     v := sa.Value.Forward(ctx, hiddenState)
     v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

@@ -77,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 // Shift applies rotary position embeddings to the key tensor for causal attention caching
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-    return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
+    return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
 }
 
 // MLP implements the feed-forward network component with SwiGLU activation