Unverified Commit bbec01c9 authored by Qubitium-modelcloud, committed by GitHub

Fix tp worker only checking req[0] for stream (#546)

parent 40e53d65
@@ -303,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
...
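Why the helper matters, shown as a minimal runnable sketch: `Req` and `Batch` here are hypothetical stand-ins (the real classes carry far more state). The old code consulted only `reqs[0].stream`, so a streaming request anywhere past index 0 went unnoticed; `has_stream()` scans the whole batch.

from dataclasses import dataclass
from typing import List

@dataclass
class Req:
    stream: bool  # stand-in: the only field this check reads

@dataclass
class Batch:
    reqs: List[Req]

    # whether batch has at least 1 streaming request
    def has_stream(self) -> bool:
        return any(r.stream for r in self.reqs)

batch = Batch(reqs=[Req(stream=False), Req(stream=True)])
assert batch.reqs[0].stream is False  # old check misses the streamer at index 1
assert batch.has_stream() is True     # new check finds it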
@@ -5,7 +5,7 @@ import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -253,7 +253,7 @@ class ModelTpServer:
                     self.running_batch = None
                     break
 
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                if self.out_pyobjs and self.running_batch.has_stream():
                     break
         else:
             # Check the available size
@@ -314,7 +314,7 @@ class ModelTpServer:
             )
             self.forward_queue.append(req)
 
-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
             and len(self.running_batch.reqs) > self.max_running_requests
...
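On the consumer side, the check gates an early break out of the decode loop so that buffered output in `out_pyobjs` gets flushed to clients as soon as any request is streaming. A condensed, hypothetical sketch of that loop (everything except `out_pyobjs`, `running_batch`, `is_empty`, and `has_stream` is an assumption, not the repo's actual code):

from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class Req:
    stream: bool
    budget: int = 3  # hypothetical: decode steps left before the request finishes

@dataclass
class Batch:
    reqs: List[Req]

    def is_empty(self) -> bool:
        return len(self.reqs) == 0

    def has_stream(self) -> bool:
        return any(r.stream for r in self.reqs)

class TpWorker:
    # Hypothetical condensation of the tp worker's decode loop.
    def __init__(self, batch: Batch):
        self.running_batch: Optional[Batch] = batch
        self.out_pyobjs: List[Any] = []

    def forward_decode_batch(self, batch: Batch) -> None:
        # Stand-in for one decode step: emit output for streaming requests
        # and retire requests whose budget is spent.
        for r in batch.reqs:
            r.budget -= 1
            if r.stream:
                self.out_pyobjs.append("partial output")
        batch.reqs = [r for r in batch.reqs if r.budget > 0]

    def decode_step(self) -> None:
        for _ in range(10):  # run a few decode batches continuously
            self.forward_decode_batch(self.running_batch)
            if self.running_batch.is_empty():
                self.running_batch = None
                break
            # The fix: break early if ANY request streams, not just reqs[0].
            if self.out_pyobjs and self.running_batch.has_stream():
                break

worker = TpWorker(Batch([Req(stream=False), Req(stream=True)]))
worker.decode_step()
assert worker.out_pyobjs  # flushed after the first iteration, not the tenth

With the old `reqs[0].stream` condition, this batch would spin through all ten decode iterations before flushing, delaying streamed tokens for any request that did not happen to sit at index 0.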