OpenDAS / ollama — Commit ad95d5b3 (unverified)

use split activations when possible (#12293)

* use ggml_*_split activations when possible
* forward qkv

Authored Sep 16, 2025 by Michael Yang; committed by GitHub on Sep 16, 2025.
Parent: c253433d
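At the call sites below, the commit collapses the gated-MLP pattern "activate the gate projection, then multiply by the up projection" into a single activation call that takes the up projection as an argument, so the ggml backend can lower it to one fused ggml_*_split op. A minimal sketch of the pattern (illustrative only; the `MLP` type and import paths are assumed from the repo layout, not part of this commit):

```go
package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// MLP is an illustrative gated MLP like the ones in the model packages below.
type MLP struct {
	Gate *nn.Linear
	Up   *nn.Linear
	Down *nn.Linear
}

func (mlp *MLP) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
	// Before: x = mlp.Gate.Forward(ctx, x).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, x))
	// After:  the up projection is passed into SILU, which the ggml backend can
	// emit as a single fused swiglu-split op instead of SILU followed by Mul.
	x = mlp.Gate.Forward(ctx, x).SILU(ctx, mlp.Up.Forward(ctx, x))
	return mlp.Down.Forward(ctx, x)
}
```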
Showing 16 changed files with 59 additions and 50 deletions (+59 -50).
ml/backend.go                          +6  -5
ml/backend/ggml/ggml.go                +21 -10
ml/nn/attention.go                     +2  -0
model/models/gemma2/model.go           +1  -1
model/models/gemma3/model_text.go      +1  -1
model/models/gemma3n/model_text.go     +2  -3
model/models/gptoss/model.go           +1  -1
model/models/llama/model.go            +1  -1
model/models/llama4/model_text.go      +8  -8
model/models/mistral3/model_text.go    +1  -1
model/models/mistral3/model_vision.go  +1  -1
model/models/mllama/model_text.go      +1  -1
model/models/qwen2/model.go            +1  -1
model/models/qwen25vl/model_text.go    +1  -1
model/models/qwen25vl/model_vision.go  +1  -2
model/models/qwen3/model.go            +10 -13
ml/backend.go

```diff
@@ -430,12 +430,13 @@ type Tensor interface {
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	GELU(ctx Context) Tensor
 	QuickGELU(ctx Context) Tensor
-	SILU(ctx Context) Tensor
-	RELU(ctx Context) Tensor
+	GELU(ctx Context, up ...Tensor) Tensor
+	SILU(ctx Context, up ...Tensor) Tensor
+	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
-	SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+	// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
+	SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
 
 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor
```
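The variadic `up ...Tensor` parameter keeps existing callers source-compatible: calling the method with no extra argument behaves like the old element-wise activation, while passing one tensor performs the activation and the gating multiply in one step. A hedged sketch of both forms (tensor names illustrative; assumes the `ml` package is imported as in ml/backend.go):

```go
// gate and up are assumed to be projections of matching shape.
func gatedGELU(ctx ml.Context, gate, up ml.Tensor) (plain, gated ml.Tensor) {
	plain = gate.GELU(ctx)     // old single-argument form, unchanged behaviour
	gated = gate.GELU(ctx, up) // new form: GELU(gate) * up in one fused op
	return plain, gated
}
```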
ml/backend/ggml/ggml.go

```diff
@@ -1431,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
 func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
```
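When a second tensor is supplied, the backend dispatches to ggml's fused split kernels (ggml_geglu_split, ggml_swiglu_split, ggml_reglu_split); otherwise it falls back to the existing in-place activations. Per element the fused call is expected to match the previous two-op form (activate the gate, then multiply by the up projection). A rough scalar sketch of what `SILU(ctx, up)` computes, assuming `ggml_swiglu_split(gate, up)` is silu(gate) * up as implied by the call-site rewrites below:

```go
package example

import "math"

// siluGate mirrors the per-element math the fused swiglu-split kernel is assumed
// to perform: silu(gate) * up, i.e. the same result as SILU(ctx) followed by Mul(ctx, up).
func siluGate(gate, up float32) float32 {
	silu := gate / float32(1+math.Exp(-float64(gate)))
	return silu * up
}
```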
ml/nn/attention.go

```diff
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
 		}
 
+		ctx.Forward(key, value)
 		if cache != nil {
 			cache.Put(ctx, key, value)
 		}
```
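The two added `ctx.Forward` calls are the "forward qkv" part of the commit message: the projected query, key, and value tensors are handed to the compute context explicitly before the cache and attention steps. A hypothetical caller, for orientation only (the 1/sqrt(d_k) scale and the import paths are assumptions, matching the usual attention scaling in these model packages):

```go
package example

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// attend sketches a call to AttentionWithSinks; q/k/v/sinks shapes must satisfy
// the checks above (d_k matches between q and k, seq_len_k between k and v).
func attend(ctx ml.Context, q, k, v, sinks ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
	return nn.AttentionWithSinks(ctx, q, k, v, sinks, 1/math.Sqrt(float64(headDim)), cache)
}
```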
model/models/gemma2/model.go

```diff
@@ -138,7 +138,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/gemma3/model_text.go

```diff
@@ -123,7 +123,7 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/gemma3n/model_text.go

```diff
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	}
 
 	active = d.PerLayerInputGate.Forward(ctx, active)
-	active = active.GELU(ctx)
-	active = active.Mul(ctx, perLayerInput)
+	active = active.GELU(ctx, perLayerInput)
 
 	active = d.PerLayerProjection.Forward(ctx, active)
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
 		hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
 	}
 
-	hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+	hiddenStates = hiddenStates.GELU(ctx, upStates)
 	hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
 	return hiddenStates
 }
```
model/models/gptoss/model.go

```diff
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
 	}
 
-	hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+	hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
 
 	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
```
model/models/llama/model.go

```diff
@@ -118,7 +118,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/llama4/model_text.go

```diff
@@ -58,14 +58,14 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
 type TextExperts struct {
-	Gate *nn.Linear `gguf:"ffn_gate_exps"`
-	Up   *nn.Linear `gguf:"ffn_up_exps"`
-	Down *nn.Linear `gguf:"ffn_down_exps"`
+	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
 }
 
 func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
 	hiddenStates = hiddenStates.Mul(ctx, scores)
 
-	upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
-	gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
-	downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+	upStates := e.Up.Forward(ctx, hiddenStates, experts)
+	gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+	downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
 
 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
 }
 
 func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
```
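Here (and in qwen3 below) the expert projections switch from plain `nn.Linear` fields whose weights are multiplied directly via `Weight.MulmatID` to an `nn.LinearBatch` type with a `Forward(ctx, x, ids)` method. The commit does not show `LinearBatch` itself; a rough sketch of what such a wrapper could look like, inferred purely from the call sites it replaces (the real type may differ, e.g. carry a bias):

```go
// Hypothetical sketch, not the commit's implementation of nn.LinearBatch.
type LinearBatch struct {
	Weight ml.Tensor
}

// Forward multiplies x by the expert weights selected by ids, mirroring the
// Weight.MulmatID(ctx, x, ids) calls that this commit removes from the models.
func (l *LinearBatch) Forward(ctx ml.Context, x, ids ml.Tensor) ml.Tensor {
	return l.Weight.MulmatID(ctx, x, ids)
}
```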
model/models/mistral3/model_text.go

```diff
@@ -65,7 +65,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/mistral3/model_vision.go

```diff
@@ -51,7 +51,7 @@ type VisionMLP struct {
 }
 
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
```
model/models/mllama/model_text.go

```diff
@@ -58,7 +58,7 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/qwen2/model.go

```diff
@@ -59,7 +59,7 @@ type MLP struct {
 }
 
 func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
```
model/models/qwen25vl/model_text.go

```diff
@@ -90,7 +90,7 @@ type MLP struct {
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
 	// Apply SwiGLU activation gating
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	// Project back to hidden dimension
 	return mlp.Down.Forward(ctx, hiddenState)
 }
```
model/models/qwen25vl/model_vision.go

```diff
@@ -100,8 +100,7 @@ type VisionMLP struct {
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	// Using activation as specified in config (likely GELU or SiLU/Swish)
 	gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
-	upOutput := mlp.Up.Forward(ctx, hiddenStates)
-	hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
+	hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
```
model/models/qwen3/model.go

```diff
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
 }
 
 type Attention struct {
-	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Query     *nn.Linear  `gguf:"attn_q"`
-	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
+	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
 	Key       *nn.Linear  `gguf:"attn_k"`
+	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
 	Value     *nn.Linear  `gguf:"attn_v"`
 	Output    *nn.Linear  `gguf:"attn_output"`
 }
@@ -65,10 +65,10 @@ type MLP interface {
 }
 
 type sparse struct {
-	Router *nn.Linear `gguf:"ffn_gate_inp"`
-	Gate   *nn.Linear `gguf:"ffn_gate_exps"`
-	Up     *nn.Linear `gguf:"ffn_up_exps"`
-	Down   *nn.Linear `gguf:"ffn_down_exps"`
+	Router *nn.Linear      `gguf:"ffn_gate_inp"`
+	Gate   *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up     *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down   *nn.LinearBatch `gguf:"ffn_down_exps"`
 }
 
 func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
 	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
 
-	upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-
-	hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
-	hiddenStates = hiddenStates.SILU(ctx)
-	hiddenStates = hiddenStates.Mul(ctx, upStates)
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
 
-	experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)
 
 	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
 }
 
 func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
```