Unverified Commit caac5c2e authored by courage17340's avatar courage17340 Committed by GitHub
Browse files

[Bugfix][Core] fix abort_seq_group and memory leak when n>1 (#14326)


Signed-off-by: default avatarcourage17340 <courage17340@163.com>
parent 6bd1dd9d
...@@ -16,8 +16,9 @@ from vllm.logger import init_logger ...@@ -16,8 +16,9 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceGroupBase, SequenceGroupMetadata,
SequenceStage, SequenceStatus) SequenceGroupMetadataDelta, SequenceStage,
SequenceStatus)
from vllm.utils import Device, PyObjectCache from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -561,7 +562,11 @@ class Scheduler: ...@@ -561,7 +562,11 @@ class Scheduler:
# Only for testing purposes. # Only for testing purposes.
self.swapped.append(seq_group) self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: def abort_seq_group(
self,
request_id: Union[str, Iterable[str]],
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
) -> None:
"""Aborts a sequence group with the given ID. """Aborts a sequence group with the given ID.
Check if the sequence group with the given ID Check if the sequence group with the given ID
...@@ -573,21 +578,29 @@ class Scheduler: ...@@ -573,21 +578,29 @@ class Scheduler:
Args: Args:
request_id: The ID(s) of the sequence group to abort. request_id: The ID(s) of the sequence group to abort.
seq_id_to_seq_group: helper for groups with n>1
""" """
if isinstance(request_id, str): if isinstance(request_id, str):
request_id = (request_id, ) request_id = (request_id, )
request_ids = set(request_id) request_ids = set(request_id)
seq_id_to_seq_group = seq_id_to_seq_group or {}
for state_queue in [self.waiting, self.running, self.swapped]: for state_queue in [self.waiting, self.running, self.swapped]:
aborted_groups: List[SequenceGroup] = [] aborted_groups: List[SequenceGroup] = []
for seq_group in state_queue: for seq_group in state_queue:
if not request_ids: # When n>1, seq_group.request_id looks like
# Using 'break' here may add two extra iterations, # foo_parallel_sample_0, while request_ids is just foo, and we
# but is acceptable to reduce complexity. # should resolve it as real_request_id to match.
break if seq_group.request_id in seq_id_to_seq_group:
if seq_group.request_id in request_ids: real_request_id = seq_id_to_seq_group[
seq_group.request_id].group_id
else:
real_request_id = seq_group.request_id
if real_request_id in request_ids:
# Appending aborted group into pending list. # Appending aborted group into pending list.
aborted_groups.append(seq_group) aborted_groups.append(seq_group)
request_ids.remove(seq_group.request_id) # We can't remove real_request_id in request_ids here,
# because there may be other seq groups sharing the same
# real_request_id
for aborted_group in aborted_groups: for aborted_group in aborted_groups:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(aborted_group) state_queue.remove(aborted_group)
...@@ -598,6 +611,8 @@ class Scheduler: ...@@ -598,6 +611,8 @@ class Scheduler:
continue continue
seq.status = SequenceStatus.FINISHED_ABORTED seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq) self.free_seq(seq)
if aborted_group.request_id in seq_id_to_seq_group:
del seq_id_to_seq_group[aborted_group.request_id]
self._free_seq_group_cross_attn_blocks(aborted_group) self._free_seq_group_cross_attn_blocks(aborted_group)
......
...@@ -887,7 +887,8 @@ class LLMEngine: ...@@ -887,7 +887,8 @@ class LLMEngine:
>>> engine.abort_request(request_id) >>> engine.abort_request(request_id)
""" """
for scheduler in self.scheduler: for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id) scheduler.abort_seq_group(
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
...@@ -1354,6 +1355,11 @@ class LLMEngine: ...@@ -1354,6 +1355,11 @@ class LLMEngine:
finished_requests_ids = self.scheduler[ finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids() virtual_engine].get_and_reset_finished_requests_ids()
# When n>1, elements in self.seq_id_to_seq_group should be deleted
# here, otherwise memory leaks.
for finished_request_id in finished_requests_ids:
if finished_request_id in self.seq_id_to_seq_group:
del self.seq_id_to_seq_group[finished_request_id]
# Maybe switch from async mode to sync mode # Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0: if not allow_async_output_proc and len(ctx.output_queue) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment