OpenDAS / ollama · Commits

Commit 0c220935
Authored Mar 19, 2025 by Jesse Gross
Committed by Jesse Gross on Mar 20, 2025
input: Rename Options to Batch
Options is no longer very descriptive of this struct.
Parent: ffbfe833
Showing 14 changed files with 73 additions and 60 deletions (+73, -60).
kvcache/cache.go                   +1  -1
kvcache/causal.go                  +6  -6
kvcache/causal_test.go             +1  -1
kvcache/encoder.go                 +3  -3
kvcache/wrapper.go                 +4  -4
model/input/input.go               +19 -6
model/model.go                     +7  -7
model/model_test.go                +1  -1
model/models/gemma2/model.go       +4  -4
model/models/gemma3/model.go       +5  -5
model/models/gemma3/model_text.go  +2  -2
model/models/llama/model.go        +4  -4
model/models/mllama/model.go       +6  -6
runner/ollamarunner/runner.go      +10 -10
kvcache/cache.go

@@ -52,7 +52,7 @@ type Cache interface {
     // StartForward is called before the start of the model's forward pass.
     // For each token in the coming batch, there must be a corresponding
     // entry in positions and seqs.
-    StartForward(ctx ml.Context, opts input.Options) error
+    StartForward(ctx ml.Context, batch input.Batch) error

     // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
     CopyPrefix(srcSeq, dstSeq int, len int32)
kvcache/causal.go

@@ -140,10 +140,10 @@ func (c *Causal) Close() {
     }
 }

-func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
-    c.curBatchSize = len(opts.Positions)
-    c.curSequences = opts.Sequences
-    c.curPositions = opts.Positions
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
+    c.curBatchSize = len(batch.Positions)
+    c.curSequences = batch.Sequences
+    c.curPositions = batch.Positions
     c.opts.Except = nil

     var err error
@@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
     }

     c.curCellRange = newRange()
-    for i, pos := range opts.Positions {
-        seq := opts.Sequences[i]
+    for i, pos := range batch.Positions {
+        seq := batch.Sequences[i]

         c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
kvcache/causal_test.go

@@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
         context := backend.NewContext()
         defer context.Close()

-        err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
+        err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
         if err != nil {
             panic(err)
         }
kvcache/encoder.go

@@ -79,10 +79,10 @@ func (c *EncoderCache) Close() {
     }
 }

-func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
     // We work with the most recent image
-    if len(opts.Multimodal) > 0 {
-        c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+    if len(batch.Multimodal) > 0 {
+        c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
     }

     return nil
kvcache/wrapper.go

@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
     }
 }

-func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
     for i, cache := range c.caches {
-        err := cache.StartForward(ctx, opts)
+        err := cache.StartForward(ctx, batch)
         if err != nil {
             // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
             for j := i - 1; j >= 0; j-- {
-                for k := range opts.Positions {
-                    _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
+                for k := range batch.Positions {
+                    _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
                 }
             }

             return err
model/input/input.go

@@ -33,11 +33,24 @@ type MultimodalIndex struct {
     Multimodal any
 }

-// Options contains the inputs for a model forward pass
-type Options struct {
-    Inputs []int32
+// Batch contains the inputs for a model forward pass
+type Batch struct {
+    // Inputs is the input tokens, including placeholders for multimodal inputs.
+    Inputs []int32

     // Multimodal is a set of multimodal embeddings previously created by
     // EncodeMultimodal, along with an index into Inputs. Unused for text-only
     // models or for batches without multimodal elements.
     Multimodal []MultimodalIndex

-    Positions []int32
-    Sequences []int
-
-    Outputs []int32
+    // Positions is the position for each Input, relative to its sequence. Equal
+    // in length to Inputs.
+    Positions []int32
+
+    // Sequences is the sequence for each Input. Equal in length to Inputs.
+    Sequences []int
+
+    // Outputs are the set of indicies into Inputs for which output data should
+    // be returned.
+    Outputs []int32
 }
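For orientation, a hedged sketch of what a populated Batch looks like under the invariants spelled out in the new doc comments (Positions and Sequences are per-Input and equal in length to Inputs; Outputs holds indices into Inputs). The concrete token values and counts below are made up for illustration and are not part of the commit.

    // Hypothetical example: three tokens in sequence 0 and one token in
    // sequence 1, returning output only for each sequence's final token.
    batch := input.Batch{
        Inputs:    []int32{101, 2023, 102, 3045},
        Positions: []int32{0, 1, 2, 0}, // position of each token within its own sequence
        Sequences: []int{0, 0, 0, 1},   // sequence ID for each token
        Outputs:   []int32{2, 3},       // indices into Inputs whose output data is wanted
    }
    _ = batch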
model/model.go

@@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-    Forward(ml.Context, input.Options) (ml.Tensor, error)
+    Forward(ml.Context, input.Batch) (ml.Tensor, error)

     Backend() ml.Backend
     Config() config
@@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
         t.Kind() == reflect.Slice
 }

-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
-    if len(opts.Positions) != len(opts.Sequences) {
-        return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
+    if len(batch.Positions) != len(batch.Sequences) {
+        return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
     }

-    if len(opts.Positions) < 1 {
+    if len(batch.Positions) < 1 {
         return nil, errors.New("batch size cannot be less than 1")
     }

     cache := m.Config().Cache
     if cache != nil {
-        err := cache.StartForward(ctx, opts)
+        err := cache.StartForward(ctx, batch)
         if err != nil {
             return nil, err
         }
     }

-    t, err := m.Forward(ctx, opts)
+    t, err := m.Forward(ctx, batch)
     if err != nil {
         return nil, err
     }
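For a downstream implementer, the only interface-level change is the Forward signature. A minimal sketch follows, assuming only the ml and input packages shown in this diff; toyModel and its placeholder body are hypothetical and do not appear in the commit.

    type toyModel struct{}

    // Forward now receives input.Batch rather than input.Options.
    func (m *toyModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
        // Lift the token IDs into a tensor, as the real models in this diff do.
        tokens, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
        if err != nil {
            return nil, err
        }
        _ = tokens // a real model would embed the tokens and run its layers here
        return nil, nil
    }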
model/model_test.go

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {

 type notTextProcessorModel struct{}

-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
     panic("unimplemented")
 }
model/models/gemma2/model.go

@@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
     return hiddenState.Add(ctx, residual)
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+    inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
     if err != nil {
         return nil, err
     }

-    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+    positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
     if err != nil {
         return nil, err
     }

-    outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+    outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
     if err != nil {
         return nil, err
     }
model/models/gemma3/model.go

@@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
     return result, nil
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+    inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
     if err != nil {
         return nil, err
     }

-    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+    positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
     if err != nil {
         return nil, err
     }

-    outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+    outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
     if err != nil {
         return nil, err
     }

-    return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+    return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
 }

 func init() {
model/models/gemma3/model_text.go

@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
     return hiddenState.Add(ctx, residual)
 }

-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
     hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
     hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

     // set image embeddings
     var except []int
-    for _, image := range opts.Multimodal {
+    for _, image := range batch.Multimodal {
         visionOutputs := image.Multimodal.(ml.Tensor)
         ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
model/models/llama/model.go

@@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
     return hiddenState.Add(ctx, residual)
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+    inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
     if err != nil {
         return nil, err
     }

-    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+    positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
     if err != nil {
         return nil, err
     }

-    outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+    outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
     if err != nil {
         return nil, err
     }
model/models/mllama/model.go

@@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
     return inputs, nil
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     var crossAttentionStates ml.Tensor
-    if len(opts.Multimodal) > 0 {
-        images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+    if len(batch.Multimodal) > 0 {
+        images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
         if len(images) > 0 {
             crossAttentionStates = images[len(images)-1]
         }
     }

-    inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+    inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
     if err != nil {
         return nil, err
     }

-    positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+    positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
     if err != nil {
         return nil, err
     }

-    outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+    outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
     if err != nil {
         return nil, err
     }
runner/ollamarunner/runner.go

@@ -348,7 +348,7 @@ func (s *Server) processBatch() error {
     }
     defer s.mu.Unlock()

-    var options input.Options
+    var batch input.Batch

     for i, seq := range s.seqs {
         if seq == nil {
@@ -395,17 +395,17 @@ func (s *Server) processBatch() error {
                 }
             }

-            options.Inputs = append(options.Inputs, inp.Token)
+            batch.Inputs = append(batch.Inputs, inp.Token)
             if inp.Multimodal != nil {
-                options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
+                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
             }

-            options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
-            options.Sequences = append(options.Sequences, seq.cache.Id)
+            batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
+            batch.Sequences = append(batch.Sequences, seq.cache.Id)

-            seq.iBatch = len(options.Outputs)
+            seq.iBatch = len(batch.Outputs)
             if j+1 == len(seq.inputs) {
-                options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
+                batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
             }
             seq.pendingInputs = append(seq.pendingInputs, inp)
         }
@@ -413,14 +413,14 @@ func (s *Server) processBatch() error {
         seq.inputs = seq.inputs[len(seq.pendingInputs):]
     }

-    if len(options.Inputs) == 0 {
+    if len(batch.Inputs) == 0 {
         return nil
     }

     ctx := s.model.Backend().NewContext()
     defer ctx.Close()

-    modelOutput, err := model.Forward(ctx, s.model, options)
+    modelOutput, err := model.Forward(ctx, s.model, batch)
     if err != nil {
         return fmt.Errorf("failed to decode batch: %w", err)
     }
@@ -460,7 +460,7 @@ func (s *Server) processBatch() error {
         }

         // sample a token
-        vocabSize := len(logits) / len(options.Outputs)
+        vocabSize := len(logits) / len(batch.Outputs)
         token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
         if err != nil {
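The runner grows the batch one token at a time, which is what keeps Inputs, Positions, and Sequences the same length. As a standalone illustration of that append pattern, here is a hedged sketch; the helper name and simplified arguments are mine and do not appear in the commit.

    // appendToken adds one token to the batch, keeping Inputs, Positions, and
    // Sequences equal in length, and optionally records it as an output index.
    func appendToken(batch *input.Batch, token, pos int32, seqID int, wantOutput bool) {
        batch.Inputs = append(batch.Inputs, token)
        batch.Positions = append(batch.Positions, pos)
        batch.Sequences = append(batch.Sequences, seqID)
        if wantOutput {
            batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
        }
    }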