"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "24b8b5cf5e44b585e8808c34616cfa20e8e39866"
Commit a8d9c264 authored and committed by Jesse Gross

llamarunner: Record the time for all batches during prompt processing

Currently, we only record the time for the last batch when processing
the prompt. This results in unrealistically high numbers for the
old llama runner.

Before:
total duration:       31.273112939s
load duration:        4.97054657s
prompt eval count:    32768 token(s)
prompt eval duration: 235.137439ms
prompt eval rate:     139356.80 tokens/s
eval count:           1873 token(s)
eval duration:        18.173182374s
eval rate:            103.06 tokens/s

After:
total duration:       30.024798033s
load duration:        4.758588663s
prompt eval count:    32768 token(s)
prompt eval duration: 7.779621548s
prompt eval rate:     4212.03 tokens/s
eval count:           1769 token(s)
eval duration:        17.148014223s
eval rate:            103.16 tokens/s
parent 0334e67f
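
The benchmark arithmetic makes the bug visible: 32768 tokens in 235 ms would be about 139,000 tokens/s, a rate that only makes sense if the timer covered the final batch alone, while 32768 tokens over the full 7.78 s gives the plausible 4212 tokens/s. The fix is an accumulation pattern: time every prompt batch and add the elapsed time to a running total rather than keeping only the last measurement. The following is a minimal, self-contained Go sketch of that pattern; sequence and processOneBatch are simplified, hypothetical stand-ins, not the runner's real types.

package main

import (
	"fmt"
	"time"
)

// sequence is a simplified stand-in for the runner's per-request state.
type sequence struct {
	processingDuration time.Duration // accumulated across all prompt batches
	promptTokens       int
}

// processOneBatch stands in for decoding one batch of prompt tokens.
// The point is the accumulation: every batch's elapsed time is added
// to the total instead of overwriting it.
func processOneBatch(seq *sequence, batchTokens int) {
	t := time.Now()

	time.Sleep(10 * time.Millisecond) // placeholder for the real decode

	seq.processingDuration += time.Since(t) // accumulate, don't overwrite
	seq.promptTokens += batchTokens
}

func main() {
	seq := &sequence{}
	for i := 0; i < 16; i++ { // a long prompt split into 16 batches
		processOneBatch(seq, 2048)
	}
	rate := float64(seq.promptTokens) / seq.processingDuration.Seconds()
	fmt.Printf("prompt eval rate: %.2f tokens/s\n", rate)
}

Had processOneBatch written seq.processingDuration = time.Since(t) instead, the reported duration would shrink to a single batch's worth, which is exactly the inflated-rate symptom in the Before numbers.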
@@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()

 	var batch *llama.Batch
+	var numOutputs int

 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}

-			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			output := i+1 == len(seq.inputs)
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
+			if output {
+				numOutputs++
+			}
+
 			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
@@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}

+	if numOutputs > 0 {
+		s.lc.Synchronize()
+	}
+
 	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
@@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
+			seq.processingDuration += time.Since(t)
 			continue
 		}

-		s.lc.Synchronize()
-
 		seq.numDecoded++
 		if seq.numDecoded > 1 {
 			seq.generationDuration += time.Since(t)
...
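
One detail in the diff deserves a note: s.lc.Synchronize() moves from the sampling path up to just after the decode, guarded by numOutputs > 0 so batches that request no outputs skip the wait. If decoding runs asynchronously with respect to the host, reading the clock before synchronizing measures only how long it took to enqueue the work; synchronizing first means the time.Since(t) updates below it cover work that has actually finished. A hedged Go analogy, using a goroutine and channel to stand in for the device and for Synchronize (these names are illustrative, not from the runner):

package main

import (
	"fmt"
	"time"
)

// launchDecode stands in for an asynchronous decode: the work is handed
// off (to a goroutine here, to the device in the real runner) and the
// call returns immediately.
func launchDecode(done chan<- struct{}) {
	go func() {
		time.Sleep(50 * time.Millisecond) // placeholder for the real work
		done <- struct{}{}
	}()
}

func main() {
	done := make(chan struct{})
	t := time.Now()
	launchDecode(done)

	// Naive measurement: times only the launch, the way a decode looks
	// if the clock is read before synchronizing.
	fmt.Printf("without sync: %v (unrealistically small)\n", time.Since(t))

	// Waiting for completion plays the role of s.lc.Synchronize();
	// only now does the elapsed time cover the work itself.
	<-done
	fmt.Printf("with sync:    %v\n", time.Since(t))
}

The numOutputs > 0 guard is a small optimization on top of this: if nothing in the batch produced an output to time or sample, there is no reason to stall the host waiting for the device.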