Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
1188f408
Unverified
Commit
1188f408
authored
Oct 28, 2025
by
Michael Yang
Committed by
GitHub
Oct 28, 2025
Browse files
s/From*Slice/From*s/ (#12255)
parent
15c7d30d
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
82 additions
and
63 deletions
+82
-63
kvcache/causal.go
kvcache/causal.go
+2
-2
kvcache/causal_test.go
kvcache/causal_test.go
+9
-9
ml/backend.go
ml/backend.go
+6
-3
ml/backend/ggml/ggml.go
ml/backend/ggml/ggml.go
+39
-23
model/models/bert/embed.go
model/models/bert/embed.go
+1
-1
model/models/deepseek2/model.go
model/models/deepseek2/model.go
+1
-1
model/models/gemma2/model.go
model/models/gemma2/model.go
+1
-1
model/models/gemma3/model.go
model/models/gemma3/model.go
+1
-1
model/models/gemma3/model_text.go
model/models/gemma3/model_text.go
+1
-1
model/models/gemma3n/model_text.go
model/models/gemma3n/model_text.go
+2
-2
model/models/gptoss/model.go
model/models/gptoss/model.go
+2
-2
model/models/llama/model.go
model/models/llama/model.go
+1
-1
model/models/llama4/model.go
model/models/llama4/model.go
+3
-3
model/models/llama4/model_text.go
model/models/llama4/model_text.go
+1
-1
model/models/llama4/model_vision.go
model/models/llama4/model_vision.go
+1
-1
model/models/mistral3/model.go
model/models/mistral3/model.go
+2
-2
model/models/mistral3/model_vision.go
model/models/mistral3/model_vision.go
+3
-3
model/models/mllama/model.go
model/models/mllama/model.go
+3
-3
model/models/qwen2/model.go
model/models/qwen2/model.go
+1
-1
model/models/qwen25vl/model.go
model/models/qwen25vl/model.go
+2
-2
No files found.
kvcache/causal.go
View file @
1188f408
...
@@ -393,7 +393,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
...
@@ -393,7 +393,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
mask
[
i
]
=
float32
(
math
.
Inf
(
-
1
))
mask
[
i
]
=
float32
(
math
.
Inf
(
-
1
))
}
}
maskTensor
:=
ctx
.
Input
()
.
FromFloat
Slice
(
mask
,
length
,
batchSize
)
maskTensor
:=
ctx
.
Input
()
.
FromFloat
s
(
mask
,
length
,
batchSize
)
if
c
.
config
.
MaskDType
!=
ml
.
DTypeF32
{
if
c
.
config
.
MaskDType
!=
ml
.
DTypeF32
{
maskTensor
=
maskTensor
.
Cast
(
ctx
,
c
.
config
.
MaskDType
)
maskTensor
=
maskTensor
.
Cast
(
ctx
,
c
.
config
.
MaskDType
)
...
@@ -725,7 +725,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
...
@@ -725,7 +725,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
offsets
=
offsets
[
batchFirst
:
batchLast
+
1
]
offsets
=
offsets
[
batchFirst
:
batchLast
+
1
]
ctx
:=
c
.
backend
.
NewContext
()
ctx
:=
c
.
backend
.
NewContext
()
kShift
:=
ctx
.
Input
()
.
FromInt
Slice
(
offsets
,
len
(
offsets
))
kShift
:=
ctx
.
Input
()
.
FromInt
s
(
offsets
,
len
(
offsets
))
for
i
,
key
:=
range
c
.
keys
{
for
i
,
key
:=
range
c
.
keys
{
if
key
==
nil
{
if
key
==
nil
{
...
...
kvcache/causal_test.go
View file @
1188f408
...
@@ -477,7 +477,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
...
@@ -477,7 +477,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}
}
cache
.
SetLayer
(
0
)
cache
.
SetLayer
(
0
)
tensor
:=
context
.
FromFloat
Slice
(
test
.
in
,
test
.
inShape
...
)
tensor
:=
context
.
FromFloat
s
(
test
.
in
,
test
.
inShape
...
)
cache
.
Put
(
context
,
tensor
,
tensor
)
cache
.
Put
(
context
,
tensor
,
tensor
)
out
,
_
,
mask
:=
cache
.
Get
(
context
)
out
,
_
,
mask
:=
cache
.
Get
(
context
)
...
@@ -519,7 +519,7 @@ func TestCanResume(t *testing.T) {
...
@@ -519,7 +519,7 @@ func TestCanResume(t *testing.T) {
}
}
cache
.
SetLayer
(
0
)
cache
.
SetLayer
(
0
)
tensor
:=
context
.
FromFloat
Slice
([]
float32
{
1
,
2
,
3
,
4
,
5
},
1
,
1
,
5
)
tensor
:=
context
.
FromFloat
s
([]
float32
{
1
,
2
,
3
,
4
,
5
},
1
,
1
,
5
)
cache
.
Put
(
context
,
tensor
,
tensor
)
cache
.
Put
(
context
,
tensor
,
tensor
)
// with window size 4, nothing has slid out of the window yet
// with window size 4, nothing has slid out of the window yet
...
@@ -549,7 +549,7 @@ func TestCanResume(t *testing.T) {
...
@@ -549,7 +549,7 @@ func TestCanResume(t *testing.T) {
}
}
cache
.
SetLayer
(
0
)
cache
.
SetLayer
(
0
)
tensor
=
context
.
FromFloat
Slice
([]
float32
{
6
},
1
,
1
,
1
)
tensor
=
context
.
FromFloat
s
([]
float32
{
6
},
1
,
1
,
1
)
cache
.
Put
(
context
,
tensor
,
tensor
)
cache
.
Put
(
context
,
tensor
,
tensor
)
// only the latest position has overlapping windows
// only the latest position has overlapping windows
...
@@ -594,7 +594,7 @@ func TestCanResumeSWAMem(t *testing.T) {
...
@@ -594,7 +594,7 @@ func TestCanResumeSWAMem(t *testing.T) {
}
}
cache
.
SetLayer
(
0
)
cache
.
SetLayer
(
0
)
tensor
:=
context
.
FromFloat
Slice
([]
float32
{
1
,
2
,
3
,
4
,
5
,
6
,
7
},
1
,
1
,
7
)
tensor
:=
context
.
FromFloat
s
([]
float32
{
1
,
2
,
3
,
4
,
5
,
6
,
7
},
1
,
1
,
7
)
cache
.
Put
(
context
,
tensor
,
tensor
)
cache
.
Put
(
context
,
tensor
,
tensor
)
// shift window by adding position 7
// shift window by adding position 7
...
@@ -607,7 +607,7 @@ func TestCanResumeSWAMem(t *testing.T) {
...
@@ -607,7 +607,7 @@ func TestCanResumeSWAMem(t *testing.T) {
}
}
cache
.
SetLayer
(
0
)
cache
.
SetLayer
(
0
)
tensor
=
context
.
FromFloat
Slice
([]
float32
{
8
},
1
,
1
,
1
)
tensor
=
context
.
FromFloat
s
([]
float32
{
8
},
1
,
1
,
1
)
cache
.
Put
(
context
,
tensor
,
tensor
)
cache
.
Put
(
context
,
tensor
,
tensor
)
// only the latest position has overlapping windows
// only the latest position has overlapping windows
...
@@ -670,7 +670,7 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
...
@@ -670,7 +670,7 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return
c
.
Empty
(
dtype
,
shape
...
)
return
c
.
Empty
(
dtype
,
shape
...
)
}
}
func
(
c
*
testContext
)
FromFloat
Slice
(
s
[]
float32
,
shape
...
int
)
ml
.
Tensor
{
func
(
c
*
testContext
)
FromFloat
s
(
s
[]
float32
,
shape
...
int
)
ml
.
Tensor
{
t
:=
c
.
Empty
(
ml
.
DTypeF32
,
shape
...
)
.
(
*
testTensor
)
t
:=
c
.
Empty
(
ml
.
DTypeF32
,
shape
...
)
.
(
*
testTensor
)
copy
(
t
.
data
,
s
)
copy
(
t
.
data
,
s
)
...
@@ -678,13 +678,13 @@ func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
...
@@ -678,13 +678,13 @@ func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
return
t
return
t
}
}
func
(
c
*
testContext
)
FromInt
Slice
(
s
[]
int32
,
shape
...
int
)
ml
.
Tensor
{
func
(
c
*
testContext
)
FromInt
s
(
s
[]
int32
,
shape
...
int
)
ml
.
Tensor
{
f
:=
make
([]
float32
,
len
(
s
))
f
:=
make
([]
float32
,
len
(
s
))
for
i
:=
range
f
{
for
i
:=
range
f
{
f
[
i
]
=
float32
(
s
[
i
])
f
[
i
]
=
float32
(
s
[
i
])
}
}
out
:=
c
.
FromFloat
Slice
(
f
,
shape
...
)
out
:=
c
.
FromFloat
s
(
f
,
shape
...
)
out
.
(
*
testTensor
)
.
dtype
=
ml
.
DTypeI32
out
.
(
*
testTensor
)
.
dtype
=
ml
.
DTypeI32
return
out
return
out
...
@@ -696,7 +696,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
...
@@ -696,7 +696,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
s
=
append
(
s
,
i
)
s
=
append
(
s
,
i
)
}
}
out
:=
c
.
FromFloat
Slice
(
s
,
len
(
s
))
out
:=
c
.
FromFloat
s
(
s
,
len
(
s
))
out
.
(
*
testTensor
)
.
dtype
=
dtype
out
.
(
*
testTensor
)
.
dtype
=
dtype
return
out
return
out
}
}
...
...
ml/backend.go
View file @
1188f408
...
@@ -98,8 +98,9 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
...
@@ -98,8 +98,9 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
type
Context
interface
{
type
Context
interface
{
Empty
(
dtype
DType
,
shape
...
int
)
Tensor
Empty
(
dtype
DType
,
shape
...
int
)
Tensor
Zeros
(
dtype
DType
,
shape
...
int
)
Tensor
Zeros
(
dtype
DType
,
shape
...
int
)
Tensor
FromFloatSlice
(
s
[]
float32
,
shape
...
int
)
Tensor
FromBytes
(
dtype
DType
,
s
[]
byte
,
shape
...
int
)
Tensor
FromIntSlice
(
s
[]
int32
,
shape
...
int
)
Tensor
FromFloats
(
s
[]
float32
,
shape
...
int
)
Tensor
FromInts
(
s
[]
int32
,
shape
...
int
)
Tensor
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
Arange
(
start
,
stop
,
step
float32
,
dtype
DType
)
Tensor
Arange
(
start
,
stop
,
step
float32
,
dtype
DType
)
Tensor
...
@@ -136,7 +137,9 @@ type Tensor interface {
...
@@ -136,7 +137,9 @@ type Tensor interface {
Bytes
()
[]
byte
Bytes
()
[]
byte
Floats
()
[]
float32
Floats
()
[]
float32
SetValueFromIntSlice
(
s
[]
int32
)
FromBytes
([]
byte
)
FromFloats
([]
float32
)
FromInts
([]
int32
)
Neg
(
ctx
Context
)
Tensor
Neg
(
ctx
Context
)
Tensor
Add
(
ctx
Context
,
t2
Tensor
)
Tensor
Add
(
ctx
Context
,
t2
Tensor
)
Tensor
...
...
ml/backend/ggml/ggml.go
View file @
1188f408
...
@@ -12,6 +12,7 @@ import "C"
...
@@ -12,6 +12,7 @@ import "C"
import
(
import
(
"context"
"context"
"encoding/binary"
"errors"
"errors"
"fmt"
"fmt"
"io"
"io"
...
@@ -871,7 +872,7 @@ func pad(length, pad C.size_t) C.size_t {
...
@@ -871,7 +872,7 @@ func pad(length, pad C.size_t) C.size_t {
return
((
length
+
pad
-
1
)
/
pad
)
*
pad
return
((
length
+
pad
-
1
)
/
pad
)
*
pad
}
}
func
(
c
*
Context
)
newTensor
(
dtype
ml
.
DType
,
shape
[]
int
)
ml
.
Tensor
{
func
(
c
*
Context
)
newTensor
(
dtype
ml
.
DType
,
shape
[]
int
)
*
Tensor
{
if
c
.
buft
==
nil
{
if
c
.
buft
==
nil
{
panic
(
"set Input or Layer before creating tensors"
)
panic
(
"set Input or Layer before creating tensors"
)
}
}
...
@@ -915,7 +916,7 @@ func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
...
@@ -915,7 +916,7 @@ func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
func
(
c
*
Context
)
Zeros
(
dtype
ml
.
DType
,
shape
...
int
)
ml
.
Tensor
{
func
(
c
*
Context
)
Zeros
(
dtype
ml
.
DType
,
shape
...
int
)
ml
.
Tensor
{
t
:=
c
.
newTensor
(
dtype
,
shape
)
t
:=
c
.
newTensor
(
dtype
,
shape
)
if
c
.
b
.
allocMemory
{
if
c
.
b
.
allocMemory
{
C
.
ggml_set_zero
(
t
.
(
*
Tensor
)
.
t
)
C
.
ggml_set_zero
(
t
.
t
)
}
}
return
t
return
t
}
}
...
@@ -936,25 +937,34 @@ func checkShape[S ~[]E, E any](s S, shape ...int) {
...
@@ -936,25 +937,34 @@ func checkShape[S ~[]E, E any](s S, shape ...int) {
}
}
}
}
func
(
c
*
Context
)
FromFloatSlice
(
s
[]
float32
,
shape
...
int
)
ml
.
Tensor
{
func
(
c
Context
)
FromBytes
(
dtype
ml
.
DType
,
s
[]
uint8
,
shape
...
int
)
ml
.
Tensor
{
// Unchecked to handle quantized types
t
:=
c
.
newTensor
(
dtype
,
shape
)
if
c
.
b
.
allocMemory
{
t
.
FromBytes
(
s
)
}
return
t
}
func
(
c
*
Context
)
FromFloats
(
s
[]
float32
,
shape
...
int
)
ml
.
Tensor
{
checkShape
(
s
,
shape
...
)
checkShape
(
s
,
shape
...
)
t
:=
c
.
newTensor
(
ml
.
DTypeF32
,
shape
)
t
:=
c
.
newTensor
(
ml
.
DTypeF32
,
shape
)
if
c
.
b
.
allocMemory
&&
len
(
s
)
>
0
{
if
c
.
b
.
allocMemory
{
C
.
ggml_backend_tensor_set
(
t
.
(
*
Tensor
)
.
t
,
unsafe
.
Pointer
(
&
s
[
0
]),
0
,
C
.
ggml_nbytes
(
t
.
(
*
Tensor
)
.
t
)
)
t
.
FromFloats
(
s
)
}
}
return
t
return
t
}
}
func
(
c
*
Context
)
FromInt
Slice
(
s
[]
int32
,
shape
...
int
)
ml
.
Tensor
{
func
(
c
*
Context
)
FromInt
s
(
s
[]
int32
,
shape
...
int
)
ml
.
Tensor
{
checkShape
(
s
,
shape
...
)
checkShape
(
s
,
shape
...
)
t
:=
c
.
newTensor
(
ml
.
DTypeI32
,
shape
)
t
:=
c
.
newTensor
(
ml
.
DTypeI32
,
shape
)
if
c
.
b
.
allocMemory
{
if
c
.
b
.
allocMemory
&&
len
(
s
)
>
0
{
t
.
FromInts
(
s
)
C
.
ggml_backend_tensor_set
(
t
.
(
*
Tensor
)
.
t
,
unsafe
.
Pointer
(
&
s
[
0
]),
0
,
C
.
ggml_nbytes
(
t
.
(
*
Tensor
)
.
t
))
}
}
return
t
return
t
...
@@ -975,7 +985,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
...
@@ -975,7 +985,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
arange
=
append
(
arange
,
int32
(
i
))
arange
=
append
(
arange
,
int32
(
i
))
}
}
return
c
.
Input
()
.
FromInt
Slice
(
arange
,
len
(
arange
))
return
c
.
Input
()
.
FromInt
s
(
arange
,
len
(
arange
))
default
:
default
:
panic
(
"unsupported dtype for arange"
)
panic
(
"unsupported dtype for arange"
)
}
}
...
@@ -1045,10 +1055,26 @@ func (t *Tensor) Floats() (data []float32) {
...
@@ -1045,10 +1055,26 @@ func (t *Tensor) Floats() (data []float32) {
return
return
}
}
func
(
t
*
Tensor
)
SetValueFromIntSlice
(
s
[]
int32
)
{
func
tensorSet
[
S
~
[]
E
,
E
byte
|
float32
|
int32
](
t
*
Tensor
,
s
S
)
{
if
len
(
s
)
>
0
{
if
len
(
s
)
==
0
{
C
.
ggml_backend_tensor_set
(
t
.
t
,
unsafe
.
Pointer
(
&
s
[
0
]),
0
,
C
.
ggml_nbytes
(
t
.
t
))
return
}
if
int
(
C
.
ggml_nbytes
(
t
.
t
))
!=
len
(
s
)
*
binary
.
Size
(
s
[
0
])
{
panic
(
"data size does not match tensor size"
)
}
}
C
.
ggml_backend_tensor_set
(
t
.
t
,
unsafe
.
Pointer
(
&
s
[
0
]),
0
,
C
.
ggml_nbytes
(
t
.
t
))
}
func
(
t
*
Tensor
)
FromBytes
(
s
[]
byte
)
{
tensorSet
(
t
,
s
)
}
func
(
t
*
Tensor
)
FromFloats
(
s
[]
float32
)
{
tensorSet
(
t
,
s
)
}
func
(
t
*
Tensor
)
FromInts
(
s
[]
int32
)
{
tensorSet
(
t
,
s
)
}
}
func
(
t
*
Tensor
)
DType
()
ml
.
DType
{
func
(
t
*
Tensor
)
DType
()
ml
.
DType
{
...
@@ -1622,13 +1648,3 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
...
@@ -1622,13 +1648,3 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
t
:
C
.
ggml_clamp
(
ctx
.
(
*
Context
)
.
ctx
,
t
.
t
,
C
.
float
(
min
),
C
.
float
(
max
)),
t
:
C
.
ggml_clamp
(
ctx
.
(
*
Context
)
.
ctx
,
t
.
t
,
C
.
float
(
min
),
C
.
float
(
max
)),
}
}
}
}
func
(
c
Context
)
FromBytes
(
dtype
ml
.
DType
,
s
[]
uint8
,
shape
...
int
)
ml
.
Tensor
{
// Unchecked to handle quantized types
t
:=
c
.
newTensor
(
dtype
,
shape
)
if
c
.
b
.
allocMemory
&&
len
(
s
)
>
0
{
C
.
ggml_backend_tensor_set
(
t
.
(
*
Tensor
)
.
t
,
unsafe
.
Pointer
(
&
s
[
0
]),
0
,
C
.
ggml_nbytes
(
t
.
(
*
Tensor
)
.
t
))
}
return
t
}
model/models/bert/embed.go
View file @
1188f408
...
@@ -30,7 +30,7 @@ type Model struct {
...
@@ -30,7 +30,7 @@ type Model struct {
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
m
.
TypeEmbedding
.
Weight
.
View
(
ctx
,
0
,
m
.
hiddenSize
))
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
m
.
TypeEmbedding
.
Weight
.
View
(
ctx
,
0
,
m
.
hiddenSize
))
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
m
.
PositionEmbedding
.
Forward
(
ctx
,
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))))
hiddenStates
=
hiddenStates
.
Add
(
ctx
,
m
.
PositionEmbedding
.
Forward
(
ctx
,
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))))
hiddenStates
=
m
.
TokenEmbeddingNorm
.
Forward
(
ctx
,
hiddenStates
,
m
.
eps
)
hiddenStates
=
m
.
TokenEmbeddingNorm
.
Forward
(
ctx
,
hiddenStates
,
m
.
eps
)
for
_
,
layer
:=
range
m
.
Layers
{
for
_
,
layer
:=
range
m
.
Layers
{
...
...
model/models/deepseek2/model.go
View file @
1188f408
...
@@ -302,7 +302,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
...
@@ -302,7 +302,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
...
...
model/models/gemma2/model.go
View file @
1188f408
...
@@ -175,7 +175,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
...
@@ -175,7 +175,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenState
=
hiddenState
.
Scale
(
ctx
,
math
.
Sqrt
(
float64
(
m
.
Options
.
hiddenSize
)))
hiddenState
=
hiddenState
.
Scale
(
ctx
,
math
.
Sqrt
(
float64
(
m
.
Options
.
hiddenSize
)))
...
...
model/models/gemma3/model.go
View file @
1188f408
...
@@ -101,7 +101,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
...
@@ -101,7 +101,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return
nil
,
err
return
nil
,
err
}
}
pixelValues
:=
ctx
.
Input
()
.
FromFloat
Slice
(
f32s
,
pixelValues
:=
ctx
.
Input
()
.
FromFloat
s
(
f32s
,
m
.
ImageProcessor
.
imageSize
,
m
.
ImageProcessor
.
imageSize
,
m
.
ImageProcessor
.
imageSize
,
m
.
ImageProcessor
.
imageSize
,
m
.
ImageProcessor
.
numChannels
,
m
.
ImageProcessor
.
numChannels
,
...
...
model/models/gemma3/model_text.go
View file @
1188f408
...
@@ -163,7 +163,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
...
@@ -163,7 +163,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
}
}
func
(
m
*
TextModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
,
cache
kvcache
.
Cache
)
ml
.
Tensor
{
func
(
m
*
TextModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
,
cache
kvcache
.
Cache
)
ml
.
Tensor
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenState
=
hiddenState
.
Scale
(
ctx
,
math
.
Sqrt
(
float64
(
m
.
TextConfig
.
hiddenSize
)))
hiddenState
=
hiddenState
.
Scale
(
ctx
,
math
.
Sqrt
(
float64
(
m
.
TextConfig
.
hiddenSize
)))
...
...
model/models/gemma3n/model_text.go
View file @
1188f408
...
@@ -29,9 +29,9 @@ type TextModel struct {
...
@@ -29,9 +29,9 @@ type TextModel struct {
}
}
func
(
m
*
TextModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
,
cache
kvcache
.
Cache
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
TextModel
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
,
cache
kvcache
.
Cache
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
// Create a tensor of a single float32 value of 1.0 to use for altup correction
// Create a tensor of a single float32 value of 1.0 to use for altup correction
one
:=
ctx
.
Input
()
.
FromFloat
Slice
([]
float32
{
1.0
},
1
)
one
:=
ctx
.
Input
()
.
FromFloat
s
([]
float32
{
1.0
},
1
)
inputs
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
,
math
.
Sqrt
(
float64
(
m
.
hiddenSize
)))
inputs
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
,
math
.
Sqrt
(
float64
(
m
.
hiddenSize
)))
inputsPerLayer
:=
m
.
PerLayerProjector
.
Forward
(
ctx
,
batch
,
inputs
,
&
m
.
TextOptions
)
inputsPerLayer
:=
m
.
PerLayerProjector
.
Forward
(
ctx
,
batch
,
inputs
,
&
m
.
TextOptions
)
...
...
model/models/gptoss/model.go
View file @
1188f408
...
@@ -30,9 +30,9 @@ type Transformer struct {
...
@@ -30,9 +30,9 @@ type Transformer struct {
// Forward implements model.Model.
// Forward implements model.Model.
func
(
m
*
Transformer
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Transformer
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
one
:=
ctx
.
Input
()
.
FromFloat
Slice
([]
float32
{
1
},
1
)
one
:=
ctx
.
Input
()
.
FromFloat
s
([]
float32
{
1
},
1
)
for
i
,
block
:=
range
m
.
TransformerBlocks
{
for
i
,
block
:=
range
m
.
TransformerBlocks
{
m
.
Cache
.
SetLayer
(
i
)
m
.
Cache
.
SetLayer
(
i
)
if
c
,
ok
:=
m
.
Cache
.
(
*
kvcache
.
WrapperCache
);
ok
{
if
c
,
ok
:=
m
.
Cache
.
(
*
kvcache
.
WrapperCache
);
ok
{
...
...
model/models/llama/model.go
View file @
1188f408
...
@@ -179,7 +179,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
...
@@ -179,7 +179,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenState
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
...
...
model/models/llama4/model.go
View file @
1188f408
...
@@ -76,7 +76,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
...
@@ -76,7 +76,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return
nil
,
err
return
nil
,
err
}
}
tilesLocal
:=
ctx
.
Input
()
.
FromFloat
Slice
(
pixelsLocal
,
size
.
X
,
size
.
Y
,
m
.
numChannels
)
tilesLocal
:=
ctx
.
Input
()
.
FromFloat
s
(
pixelsLocal
,
size
.
X
,
size
.
Y
,
m
.
numChannels
)
ratioW
,
ratioH
:=
size
.
X
/
m
.
imageSize
,
size
.
Y
/
m
.
imageSize
ratioW
,
ratioH
:=
size
.
X
/
m
.
imageSize
,
size
.
Y
/
m
.
imageSize
...
@@ -87,7 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
...
@@ -87,7 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
pixelValues
:=
tilesLocal
pixelValues
:=
tilesLocal
if
len
(
pixelsGlobal
)
>
0
{
if
len
(
pixelsGlobal
)
>
0
{
tilesGlobal
:=
ctx
.
Input
()
.
FromFloat
Slice
(
pixelsGlobal
,
m
.
imageSize
,
m
.
imageSize
,
m
.
numChannels
)
tilesGlobal
:=
ctx
.
Input
()
.
FromFloat
s
(
pixelsGlobal
,
m
.
imageSize
,
m
.
imageSize
,
m
.
numChannels
)
pixelValues
=
pixelValues
.
Concat
(
ctx
,
tilesGlobal
,
3
)
pixelValues
=
pixelValues
.
Concat
(
ctx
,
tilesGlobal
,
3
)
}
}
...
@@ -174,7 +174,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
...
@@ -174,7 +174,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
),
nil
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
),
nil
}
}
...
...
model/models/llama4/model_text.go
View file @
1188f408
...
@@ -211,7 +211,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
...
@@ -211,7 +211,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
scales
[
i
]
=
float32
(
math
.
Log
(
math
.
Floor
(((
float64
(
p
)
+
1.0
)
/
float64
(
m
.
attentionFloorScale
))
+
1.0
))
*
m
.
attentionScale
+
1.0
)
scales
[
i
]
=
float32
(
math
.
Log
(
math
.
Floor
(((
float64
(
p
)
+
1.0
)
/
float64
(
m
.
attentionFloorScale
))
+
1.0
))
*
m
.
attentionScale
+
1.0
)
}
}
attentionScales
=
ctx
.
Input
()
.
FromFloat
Slice
(
scales
,
1
,
1
,
len
(
scales
))
attentionScales
=
ctx
.
Input
()
.
FromFloat
s
(
scales
,
1
,
1
,
len
(
scales
))
}
}
for
i
,
layer
:=
range
m
.
Layers
{
for
i
,
layer
:=
range
m
.
Layers
{
...
...
model/models/llama4/model_vision.go
View file @
1188f408
...
@@ -245,7 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
...
@@ -245,7 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
}
}
}
}
ropeFreqs
:=
ctx
.
Input
()
.
FromFloat
Slice
(
freqs
,
freqDim
/
2
,
numPatches
,
2
)
ropeFreqs
:=
ctx
.
Input
()
.
FromFloat
s
(
freqs
,
freqDim
/
2
,
numPatches
,
2
)
ropeFreqs
=
ropeFreqs
.
Permute
(
ctx
,
0
,
2
,
1
,
3
)
.
Contiguous
(
ctx
)
ropeFreqs
=
ropeFreqs
.
Permute
(
ctx
,
0
,
2
,
1
,
3
)
.
Contiguous
(
ctx
)
ropeFreqs
=
ropeFreqs
.
Reshape
(
ctx
,
freqDim
,
1
,
numPatches
)
ropeFreqs
=
ropeFreqs
.
Reshape
(
ctx
,
freqDim
,
1
,
numPatches
)
...
...
model/models/mistral3/model.go
View file @
1188f408
...
@@ -114,7 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
...
@@ -114,7 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return
nil
,
err
return
nil
,
err
}
}
pixelValues
:=
ctx
.
Input
()
.
FromFloat
Slice
(
f32s
,
size
.
X
,
size
.
Y
,
m
.
ImageProcessor
.
numChannels
)
pixelValues
:=
ctx
.
Input
()
.
FromFloat
s
(
f32s
,
size
.
X
,
size
.
Y
,
m
.
ImageProcessor
.
numChannels
)
visionOutputs
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
)
visionOutputs
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
)
features
,
size
:=
m
.
MultiModalProjector
.
Forward
(
ctx
,
visionOutputs
,
size
)
features
,
size
:=
m
.
MultiModalProjector
.
Forward
(
ctx
,
visionOutputs
,
size
)
...
@@ -158,7 +158,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
...
@@ -158,7 +158,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
),
nil
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
),
nil
}
}
...
...
model/models/mistral3/model_vision.go
View file @
1188f408
...
@@ -110,8 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
...
@@ -110,8 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
}
}
}
}
h
:=
ctx
.
Input
()
.
FromFloat
Slice
(
frequenciesHeight
,
maxPatchesPerSide
,
frequencies
/
2
)
h
:=
ctx
.
Input
()
.
FromFloat
s
(
frequenciesHeight
,
maxPatchesPerSide
,
frequencies
/
2
)
w
:=
ctx
.
Input
()
.
FromFloat
Slice
(
frequenciesWidth
,
maxPatchesPerSide
,
frequencies
/
2
)
w
:=
ctx
.
Input
()
.
FromFloat
s
(
frequenciesWidth
,
maxPatchesPerSide
,
frequencies
/
2
)
h
=
h
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
h
=
h
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
w
=
w
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
w
=
w
.
Permute
(
ctx
,
1
,
0
,
2
,
3
)
.
Contiguous
(
ctx
)
...
@@ -144,7 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
...
@@ -144,7 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
}
}
}
}
positionIDs
:=
ctx
.
Input
()
.
FromInt
Slice
(
positions
,
len
(
positions
))
positionIDs
:=
ctx
.
Input
()
.
FromInt
s
(
positions
,
len
(
positions
))
positionEmbedding
:=
m
.
positionalEmbedding
(
ctx
,
positionIDs
)
positionEmbedding
:=
m
.
positionalEmbedding
(
ctx
,
positionIDs
)
cos
,
sin
:=
positionEmbedding
.
Cos
(
ctx
),
positionEmbedding
.
Sin
(
ctx
)
cos
,
sin
:=
positionEmbedding
.
Cos
(
ctx
),
positionEmbedding
.
Sin
(
ctx
)
...
...
model/models/mllama/model.go
View file @
1188f408
...
@@ -80,8 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
...
@@ -80,8 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
f32s
=
f32s
[
:
m
.
imageSize
*
m
.
imageSize
*
m
.
numChannels
*
m
.
maxNumTiles
]
f32s
=
f32s
[
:
m
.
imageSize
*
m
.
imageSize
*
m
.
numChannels
*
m
.
maxNumTiles
]
}
}
pixelValues
:=
ctx
.
Input
()
.
FromFloat
Slice
(
f32s
,
m
.
imageSize
,
m
.
imageSize
,
m
.
numChannels
,
m
.
maxNumTiles
)
pixelValues
:=
ctx
.
Input
()
.
FromFloat
s
(
f32s
,
m
.
imageSize
,
m
.
imageSize
,
m
.
numChannels
,
m
.
maxNumTiles
)
aspectRatio
:=
ctx
.
Input
()
.
FromInt
Slice
([]
int32
{
int32
(
ratio
.
rank
)},
1
)
aspectRatio
:=
ctx
.
Input
()
.
FromInt
s
([]
int32
{
int32
(
ratio
.
rank
)},
1
)
positionIDs
:=
ctx
.
Arange
(
0
,
1601
,
1
,
ml
.
DTypeI32
)
positionIDs
:=
ctx
.
Arange
(
0
,
1601
,
1
,
ml
.
DTypeI32
)
crossAttentionStates
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
,
positionIDs
,
aspectRatio
)
crossAttentionStates
:=
m
.
VisionModel
.
Forward
(
ctx
,
pixelValues
,
positionIDs
,
aspectRatio
)
...
@@ -106,7 +106,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
...
@@ -106,7 +106,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
crossAttentionStates
=
batch
.
Multimodal
[
len
(
batch
.
Multimodal
)
-
1
]
.
Multimodal
[
0
]
.
Tensor
crossAttentionStates
=
batch
.
Multimodal
[
len
(
batch
.
Multimodal
)
-
1
]
.
Multimodal
[
0
]
.
Tensor
}
}
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
// TODO: attention mask, cross attention mask
// TODO: attention mask, cross attention mask
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
crossAttentionStates
,
nil
,
m
.
Cache
.
(
*
kvcache
.
WrapperCache
)),
nil
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
crossAttentionStates
,
nil
,
m
.
Cache
.
(
*
kvcache
.
WrapperCache
)),
nil
...
...
model/models/qwen2/model.go
View file @
1188f408
...
@@ -102,7 +102,7 @@ type Model struct {
...
@@ -102,7 +102,7 @@ type Model struct {
// Forward implements model.Model.
// Forward implements model.Model.
func
(
m
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
hiddenStates
:=
m
.
TokenEmbedding
.
Forward
(
ctx
,
batch
.
Inputs
)
...
...
model/models/qwen25vl/model.go
View file @
1188f408
...
@@ -69,7 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
...
@@ -69,7 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
m
.
ImageProcessor
.
patchSize
*
m
.
ImageProcessor
.
patchSize
m
.
ImageProcessor
.
patchSize
*
m
.
ImageProcessor
.
patchSize
numPatches
:=
grid
.
Temporal
*
grid
.
Height
*
grid
.
Width
numPatches
:=
grid
.
Temporal
*
grid
.
Height
*
grid
.
Width
pixelValues
:=
ctx
.
Input
()
.
FromFloat
Slice
(
f32s
,
patchDim
,
numPatches
)
pixelValues
:=
ctx
.
Input
()
.
FromFloat
s
(
f32s
,
patchDim
,
numPatches
)
return
pixelValues
,
grid
,
nil
return
pixelValues
,
grid
,
nil
}
}
...
@@ -139,7 +139,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
...
@@ -139,7 +139,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
}
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
func
(
m
*
Model
)
Forward
(
ctx
ml
.
Context
,
batch
input
.
Batch
)
(
ml
.
Tensor
,
error
)
{
positions
:=
ctx
.
Input
()
.
FromInt
Slice
(
batch
.
Positions
,
len
(
batch
.
Positions
))
positions
:=
ctx
.
Input
()
.
FromInt
s
(
batch
.
Positions
,
len
(
batch
.
Positions
))
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
)
return
m
.
TextModel
.
Forward
(
ctx
,
batch
.
Inputs
,
positions
,
batch
.
Outputs
,
batch
,
m
.
Cache
)
}
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment