Commit 59488cc9 authored by lizhigong
Browse files

fix 修改tbo的线程管理和线程释放方式,减少对其他模块的影响

parent bd363067
...@@ -78,8 +78,6 @@ async def serve_http(app: FastAPI, ...@@ -78,8 +78,6 @@ async def serve_http(app: FastAPI,
port, process, " ".join(process.cmdline())) port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.") logger.info("Shutting down FastAPI HTTP server.")
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
return server.shutdown() return server.shutdown()
finally: finally:
watchdog_task.cancel() watchdog_task.cancel()
......
...@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest ...@@ -16,7 +16,6 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -144,7 +143,6 @@ class ExecutorBase(ABC): ...@@ -144,7 +143,6 @@ class ExecutorBase(ABC):
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop.""" """Releases parallel workers from model loop."""
finish_two_batch_overlap()
return return
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -303,7 +301,6 @@ class DistributedExecutorBase(ExecutorBase): ...@@ -303,7 +301,6 @@ class DistributedExecutorBase(ExecutorBase):
return driver_outputs return driver_outputs
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
finish_two_batch_overlap()
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
return return
......
...@@ -256,10 +256,6 @@ def _run_worker_process( ...@@ -256,10 +256,6 @@ def _run_worker_process(
and not tunable.record_untuned_is_enabled()): and not tunable.record_untuned_is_enabled()):
tunable.write_file() tunable.write_file()
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
finish_two_batch_overlap()
logger.info("Worker exiting") logger.info("Worker exiting")
......
...@@ -53,24 +53,17 @@ class TwoBatchOverlap(): ...@@ -53,24 +53,17 @@ class TwoBatchOverlap():
def init_tbo_thread(self): def init_tbo_thread(self):
self.model_input_left_queue.empty() self.model_input_left_queue.empty()
self.model_input_right_queue.empty() self.model_input_right_queue.empty()
if self.left_thread == None:
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,)) self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
self.left_thread.start() self.left_thread.start()
if self.right_thread == None:
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start() self.right_thread.start()
logger.info('tbo:two batch overlap threads start') logger.info('tbo:two batch overlap start')
def finish_thread(self): def finish_thread(self):
if self.left_thread != None:
self.model_input_left_queue.put(None)
self.left_thread.join() self.left_thread.join()
self.left_thread = None self.left_thread = None
if self.right_thread != None:
self.model_input_right_queue.put(None)
self.right_thread.join() self.right_thread.join()
self.right_thread = None self.right_thread = None
logger.info('tbo:finish threads')
@torch.inference_mode() @torch.inference_mode()
def thread_two_batch_overlap(self, queue): def thread_two_batch_overlap(self, queue):
...@@ -84,10 +77,7 @@ class TwoBatchOverlap(): ...@@ -84,10 +77,7 @@ class TwoBatchOverlap():
self.right_tid = tid self.right_tid = tid
init_tbo_forward_context(False, self.right_tid) init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream): with torch.cuda.stream(tbo_step_stream):
while True:
model_input = queue.get() model_input = queue.get()
if model_input == None:
break
profile.ProfRangePush('start') profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid) self.tbo_thread_synchronize(tid)
model_kwargs = None model_kwargs = None
...@@ -100,7 +90,6 @@ class TwoBatchOverlap(): ...@@ -100,7 +90,6 @@ class TwoBatchOverlap():
intermediate_tensors = self.intermediate_tensors_right intermediate_tensors = self.intermediate_tensors_right
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine): self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable( hidden_or_intermediate_states = self.model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
...@@ -122,10 +111,10 @@ class TwoBatchOverlap(): ...@@ -122,10 +111,10 @@ class TwoBatchOverlap():
if tid == self.left_tid: if tid == self.left_tid:
if not self.left_first: if not self.left_first:
self.sem_right.release() self.sem_right.release()
self.left_first = False
profile.ProfRangePop() profile.ProfRangePop()
self.sem_left.acquire() self.sem_left.acquire()
profile.ProfRangePush('left') profile.ProfRangePush('left')
self.left_first = False
return self.event_left_c2t, self.event_left_t2c return self.event_left_c2t, self.event_left_t2c
else: else:
self.sem_left.release() self.sem_left.release()
...@@ -147,8 +136,6 @@ class TwoBatchOverlap(): ...@@ -147,8 +136,6 @@ class TwoBatchOverlap():
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs_left, model_kwargs_left,
model_kwargs_right): model_kwargs_right):
if self.left_thread == None:
self.init_tbo_thread()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.virtual_engine = virtual_engine self.virtual_engine = virtual_engine
self.model_executable = model_executable self.model_executable = model_executable
...@@ -186,16 +173,10 @@ class TwoBatchOverlap(): ...@@ -186,16 +173,10 @@ class TwoBatchOverlap():
tbo_obj = None tbo_obj = None
def init_two_batch_overlap(): def init_two_batch_overlap():
if envs.VLLM_ENABLE_TBO:
global tbo_obj global tbo_obj
if tbo_obj == None: if tbo_obj == None:
tbo_obj = TwoBatchOverlap() tbo_obj = TwoBatchOverlap()
tbo_obj.init_tbo_thread()
def finish_two_batch_overlap():
global tbo_obj
if tbo_obj != None:
tbo_obj.finish_thread()
tbo_obj = None
def tbo_all_reduce(obj): def tbo_all_reduce(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running:
...@@ -309,6 +290,7 @@ def tbo_model_executable( ...@@ -309,6 +290,7 @@ def tbo_model_executable(
hidden_or_intermediate_states = merge_model_output(states_left, states_right) hidden_or_intermediate_states = merge_model_output(states_left, states_right)
tbo_obj.tbo_running = False tbo_obj.tbo_running = False
tbo_obj.step_event.record() tbo_obj.step_event.record()
tbo_obj.finish_thread()
current_stream.wait_event(tbo_obj.step_event) current_stream.wait_event(tbo_obj.step_event)
profile.ProfRangePop() profile.ProfRangePop()
return hidden_or_intermediate_states return hidden_or_intermediate_states
...@@ -18,7 +18,6 @@ from vllm.logger import init_logger ...@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method, resolve_obj_by_qualname, run_method,
update_environment_variables, update_environment_variables,
...@@ -113,7 +112,6 @@ class WorkerBase: ...@@ -113,7 +112,6 @@ class WorkerBase:
while True: while True:
output = self.execute_model(execute_model_req=None) output = self.execute_model(execute_model_req=None)
if output is None: if output is None:
finish_two_batch_overlap()
return None return None
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment