Commit 1a2feb2a authored by Michael Yang, committed by Michael Yang

ollamarunner: fix deadlock

hardErrCh can deadlock: run is blocked in forwardBatch waiting on
computeStartedCh, which never gets sent once computeBatch bails out
early, so the error is never received. Since the only response to
hardErrCh is to panic, just panic at the failure site instead.
parent aab21904
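
To make the failure mode concrete, here is a minimal sketch of the deadlock shape, using simplified stand-ins for the runner's run/forwardBatch/computeBatch rather than the actual code: the buffered send on hardErrCh succeeds, but the receiving side is parked on computeStartedCh and never reaches the select that would drain the error.

```go
// Minimal sketch of the deadlock shape (hypothetical stand-ins for
// run/forwardBatch/computeBatch; not the actual runner code).
package main

import "errors"

func main() {
	hardErrCh := make(chan error, 1)
	computeStartedCh := make(chan struct{})

	// computeBatch stand-in: hits a hard error and returns before
	// ever signaling computeStartedCh.
	go func() {
		hardErrCh <- errors.New("hard failure") // buffered send succeeds
	}()

	// forwardBatch stand-in: blocks here forever, so the select below
	// that would drain hardErrCh is never reached. (Here the Go runtime
	// reports "all goroutines are asleep - deadlock!"; in the real
	// runner other goroutines keep the process alive, so it just hangs.)
	<-computeStartedCh

	select {
	case err := <-hardErrCh:
		panic(err) // unreachable
	default:
	}
}
```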
@@ -321,9 +321,6 @@ type Server struct {
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// Used to signal a hard failure during async processing which will panic the runner
-	hardErrCh chan error
-
 	// Simple counter used only for trace logging batches
 	batchID int
@@ -411,8 +408,6 @@ func (s *Server) run(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			return
-		case err := <-s.hardErrCh:
-			panic(err)
 		default:
 			var err error
 			nextBatch, err := s.forwardBatch(previousBatch)
@@ -663,9 +658,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			if !s.cache.enabled {
-				s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
-				s.mu.Unlock()
-				return
+				panic("caching disabled but unable to fit entire input in a batch")
 			}
 			continue
 		}
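
The fix leans on a Go guarantee (a general language fact, not anything specific to this repo): an unrecovered panic in any goroutine terminates the whole process, so panicking at the failure site keeps hardErrCh's intended "hard failure panics the runner" behavior without the channel handoff or the s.mu.Unlock() bookkeeping. A self-contained demonstration:

```go
// An unrecovered panic in any goroutine crashes the entire process,
// not just the goroutine that panicked.
package main

import "time"

func main() {
	go func() {
		panic("hard failure in async work") // terminates the process
	}()
	time.Sleep(time.Second) // never completes; the panic wins
}
```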
@@ -720,8 +713,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
-			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
-			return
+			panic("failed to sample token")
 		}
 		nextBatchTokens[i].Token = token
@@ -738,8 +730,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
 		if err != nil {
-			s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
-			return
+			panic("failed to decode token")
 		}
 		seq.pendingResponses = append(seq.pendingResponses, piece)
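
One observation on the new panics (a side note, not part of the commit): they drop the underlying error that the old fmt.Errorf calls wrapped. If that detail mattered, a variant could panic with the wrapped error instead; sampleToken below is a hypothetical stand-in for seq.sampler.Sample:

```go
package main

import (
	"errors"
	"fmt"
)

// sampleToken is a hypothetical stand-in for seq.sampler.Sample.
func sampleToken() (int32, error) {
	return 0, errors.New("no valid logits") // stand-in failure
}

func main() {
	token, err := sampleToken()
	if err != nil {
		// Wrapping keeps the root cause visible in the panic output.
		panic(fmt.Errorf("failed to sample token: %w", err))
	}
	fmt.Println(token)
}
```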
@@ -1321,7 +1312,6 @@ func Execute(args []string) error {
 	server := &Server{
 		modelPath: *mpath,
 		status:    llm.ServerStatusLaunched,
-		hardErrCh: make(chan error, 1),
 	}
 
 	server.cond = sync.NewCond(&server.mu)