"docs/source/en/using-diffusers/write_own_pipeline.md" did not exist on "174dcd697faf88370f1e7b2eeabb059dd8f1b2f4"
Commit 26465fb8 authored by Jesse Gross, committed by Jesse Gross

ollamarunner: Worst case batch for token generation

We currently allocate the worst-case batch at the maximum batch
size, which corresponds to prompt processing. However, in some
cases the generated graph is different for small and large
batches. To ensure that we don't need to allocate memory later,
after layout has taken place, we should run the worst-case batch
both ways and take the larger amount of memory.

This does not noticeably affect loading speed, as the most expensive
part of this logic is image processing, which does not occur during
token generation.
parent 88236bc0
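
For illustration, here is a minimal sketch of the "run the worst case both ways and take the larger" idea. The `measure` callback and all names here are hypothetical stand-ins, not the ollama runner API (in the actual change below, `reserveWorstCaseGraph` builds and reserves the real compute graph instead of returning a byte count):

```go
package main

import "fmt"

// reserveWorstCase sketches the idea from the commit message: measure the
// memory needed for a max-sized prompt-processing batch and for a
// single-token generation batch, then reserve whichever is larger, so no
// further allocation is needed after layout. Hypothetical, not ollama's API.
func reserveWorstCase(measure func(batchSize int) uint64, maxBatch int) uint64 {
	promptBytes := measure(maxBatch) // worst case for prompt processing: a max-sized batch
	genBytes := measure(1)           // worst case for token generation: a single-token batch
	if genBytes > promptBytes {
		// A generation graph can need more memory even with a smaller batch,
		// because its shape may differ from the prompt-processing graph.
		return genBytes
	}
	return promptBytes
}

func main() {
	// Toy stand-in measurement: memory grows with batch size on top of a
	// fixed per-graph overhead.
	measure := func(batchSize int) uint64 { return 1<<20 + uint64(batchSize)*4096 }
	fmt.Printf("reserve %d bytes\n", reserveWorstCase(measure, 512))
}
```

The commit itself achieves this by calling `reserveWorstCaseGraph` twice, once with `prompt` true and once with it false, so the reservation covers both graph shapes: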
@@ -1009,12 +1009,17 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (s *Server) reserveWorstCaseGraph() error {
+func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
 	var err error
-	inputs := make([]*input.Input, s.batchSize)
+	batchSize := 1
+	if prompt {
+		batchSize = s.batchSize
+	}
+
+	inputs := make([]*input.Input, batchSize)
 	for i := range inputs {
 		inputs[i] = &input.Input{}
 	}
@@ -1031,7 +1036,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 	// - The result may now be larger than a batch (images may not fit in a
 	//   single batch), so trim based on what will fit and must be grouped together.
 	// - Fill out the rest of the space with text tokens.
-	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); prompt && ok {
 		mmCtx := s.model.Backend().NewContext()
 		defer mmCtx.Close()
 
@@ -1058,10 +1063,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 		}
 	}
 
-	if len(inputs) < s.batchSize {
-		newInputs := make([]*input.Input, s.batchSize)
+	if len(inputs) < batchSize {
+		newInputs := make([]*input.Input, batchSize)
 		copy(newInputs, inputs)
-		for i := len(inputs); i < s.batchSize; i++ {
+		for i := len(inputs); i < batchSize; i++ {
 			newInputs[i] = &input.Input{}
 		}
 		inputs = newInputs
@@ -1160,7 +1165,12 @@ func (s *Server) allocModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	return s.reserveWorstCaseGraph()
+	err = s.reserveWorstCaseGraph(true)
+	if err != nil {
+		return err
+	}
+
+	return s.reserveWorstCaseGraph(false)
 }
 
 // closeModel frees all memory associated with a model