Unverified Commit 80b6080d authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix async scheduling + chunked prefill + preemption (#28787)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent 03ee4811
...@@ -65,9 +65,8 @@ def test_without_spec_decoding( ...@@ -65,9 +65,8 @@ def test_without_spec_decoding(
(True, "mp", True, None, False), (True, "mp", True, None, False),
(True, "uni", True, None, False), (True, "uni", True, None, False),
(False, "mp", True, None, True), (False, "mp", True, None, True),
# Async scheduling + preemption + chunked prefill needs to be fixed (WIP) (True, "mp", True, None, True),
# (True, "mp", True, None, True), (True, "uni", True, None, True),
# (True, "uni", True, None, True),
] ]
run_tests( run_tests(
...@@ -103,9 +102,8 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): ...@@ -103,9 +102,8 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(False, "mp", True, spec_config_short, True), (False, "mp", True, spec_config_short, True),
(True, "uni", True, spec_config, False), (True, "uni", True, spec_config, False),
(True, "uni", True, spec_config_short, False), (True, "uni", True, spec_config_short, False),
# Async scheduling + preemption + chunked prefill needs to be fixed (WIP) (True, "mp", True, spec_config, True),
# (True, "mp", True, spec_config, True), (True, "uni", True, spec_config_short, True),
# (True, "uni", True, spec_config_short, True),
] ]
run_tests( run_tests(
......
...@@ -778,9 +778,7 @@ class Scheduler(SchedulerInterface): ...@@ -778,9 +778,7 @@ class Scheduler(SchedulerInterface):
assert not scheduled_in_prev_step assert not scheduled_in_prev_step
resumed_req_ids.add(req_id) resumed_req_ids.add(req_id)
if not scheduled_in_prev_step: if not scheduled_in_prev_step:
all_token_ids[req_id] = req.all_token_ids[ all_token_ids[req_id] = req.all_token_ids.copy()
: req.num_computed_tokens + num_tokens
]
new_block_ids.append( new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True) req_to_new_blocks[req_id].get_block_ids(allow_none=True)
) )
......
...@@ -97,6 +97,9 @@ class ConstantList(Generic[T], Sequence): ...@@ -97,6 +97,9 @@ class ConstantList(Generic[T], Sequence):
def __repr__(self): def __repr__(self):
return f"ConstantList({self._x})" return f"ConstantList({self._x})"
def copy(self) -> list[T]:
return self._x.copy()
class CpuGpuBuffer: class CpuGpuBuffer:
"""Buffer to easily copy tensors between CPU and GPU.""" """Buffer to easily copy tensors between CPU and GPU."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment