Unverified Commit 2863befc authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Optimization] Use Shared `CachedRequestData` Instance Across All Requests (#20232)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2965c99c
...@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, ...@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
...@@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], ...@@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling # Test initial scheduling
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled. # Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items(): for req_id, num_tokens in output.num_scheduled_tokens.items():
...@@ -225,7 +225,7 @@ def test_schedule_multimodal_requests(): ...@@ -225,7 +225,7 @@ def test_schedule_multimodal_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items(): for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids) assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
...@@ -259,7 +259,7 @@ def test_schedule_partial_requests(): ...@@ -259,7 +259,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3 assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024 assert scheduler.max_num_encoder_input_tokens == 1024
...@@ -295,7 +295,7 @@ def test_schedule_partial_requests(): ...@@ -295,7 +295,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 2 assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700 assert output.num_scheduled_tokens[requests[1].request_id] == 700
...@@ -319,7 +319,7 @@ def test_no_mm_input_chunking(): ...@@ -319,7 +319,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1 assert len(output.scheduled_new_reqs) == 1
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled # We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400 assert output.num_scheduled_tokens[requests[0].request_id] == 400
...@@ -342,7 +342,7 @@ def test_no_mm_input_chunking(): ...@@ -342,7 +342,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 1 assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800 assert output.num_scheduled_tokens[requests[0].request_id] == 800
...@@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3 assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400. # The first request is scheduled partially - 400.
...@@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule() output1 = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0 assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3 assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0 assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400
...@@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output2 = scheduler.schedule() output2 = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0 assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3 assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0 assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1
...@@ -449,23 +449,24 @@ def test_stop_via_update_from_output(): ...@@ -449,23 +449,24 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 1, num_scheduled_tokens={
requests[1].request_id: 2 requests[0].request_id: 1,
}, requests[1].request_id: 2
total_num_scheduled_tokens=3, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=3,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [], scheduled_spec_decode_tokens={
requests[1].request_id: [10] requests[0].request_id: [],
}, requests[1].request_id: [10]
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -501,23 +502,25 @@ def test_stop_via_update_from_output(): ...@@ -501,23 +502,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 3, num_scheduled_tokens={
requests[1].request_id: 2 requests[0].request_id: 3,
}, requests[1].request_id: 2
total_num_scheduled_tokens=5, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=5,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [10, 42], scheduled_spec_decode_tokens={
requests[1].request_id: [13] requests[0].request_id: [10, 42],
}, requests[1].request_id: [13]
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -551,23 +554,25 @@ def test_stop_via_update_from_output(): ...@@ -551,23 +554,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 3, num_scheduled_tokens={
requests[1].request_id: 1 requests[0].request_id: 3,
}, requests[1].request_id: 1
total_num_scheduled_tokens=4, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=4,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [10, 11], scheduled_spec_decode_tokens={
requests[1].request_id: [] requests[0].request_id: [10, 11],
}, requests[1].request_id: []
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -603,7 +608,7 @@ def test_stop_via_update_from_output(): ...@@ -603,7 +608,7 @@ def test_stop_via_update_from_output():
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3}, num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3, total_num_scheduled_tokens=3,
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
...@@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager. # EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.freed) == 0
......
...@@ -66,7 +66,7 @@ def test_basic_lifecycle(): ...@@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert len(scheduler_output.finished_req_ids) == 1 assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model() # (2b): execute_model()
...@@ -81,7 +81,7 @@ def test_basic_lifecycle(): ...@@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model() # (3b): execute_model()
......
...@@ -36,7 +36,7 @@ def test_basic_lifecycle(): ...@@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output. # Nothing running and empty scheduler output.
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0 assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0 assert scheduler_output.total_num_scheduled_tokens == 0
...@@ -158,7 +158,7 @@ def test_interleaved_lifecycle(): ...@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b]) [request_local_a, request_local_b])
...@@ -169,7 +169,7 @@ def test_interleaved_lifecycle(): ...@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b]) reqs=[request_local_a, request_local_b])
...@@ -177,14 +177,14 @@ def test_interleaved_lifecycle(): ...@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive. # STEP 4: KVs arrive.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b], [request_local_a, request_local_b],
...@@ -196,7 +196,7 @@ def test_interleaved_lifecycle(): ...@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote]) [request_local_a, request_local_b, request_remote])
......
...@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0 assert len(scheduler.finished_recving_kv_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager. # EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.freed) == 0
......
...@@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput( return SchedulerOutput(
scheduled_new_reqs=new_reqs, scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner): ...@@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner):
# finish req # finish req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={}, num_scheduled_tokens={},
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req # unschedule req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={}, num_scheduled_tokens={},
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner): ...@@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
# resume req # resume req
cached_req_data = CachedRequestData( cached_req_data = CachedRequestData(
req_id=req_id, req_ids=[req_id],
resumed_from_preemption=False, resumed_from_preemption=[False],
new_token_ids=[], new_token_ids=[[]],
new_block_ids=([], ), new_block_ids=[([], )],
num_computed_tokens=0, num_computed_tokens=[0],
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data], scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1}, num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner): ...@@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner):
# schedule req # schedule req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1}, num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1 # unschedule req_1
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1}, num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
......
...@@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput( return SchedulerOutput(
scheduled_new_reqs=new_reqs, scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner): ...@@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner):
# finish req # finish req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={}, num_scheduled_tokens={},
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req # unschedule req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={}, num_scheduled_tokens={},
total_num_scheduled_tokens=0, total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner): ...@@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner):
# resume req # resume req
cached_req_data = CachedRequestData( cached_req_data = CachedRequestData(
req_id=req_id, req_ids=[req_id],
resumed_from_preemption=False, resumed_from_preemption=[False],
new_token_ids=[], new_token_ids=[[]],
new_block_ids=([], ), new_block_ids=([[0]], ),
num_computed_tokens=0, num_computed_tokens=[0],
) )
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data], scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1}, num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner): ...@@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner):
# schedule req # schedule req
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1}, num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
...@@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1 # unschedule req_1
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1}, num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1, total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={}, scheduled_spec_decode_tokens={},
......
...@@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_size=self._block_size) block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id) self._requests_need_load.pop(new_req.req_id)
for cached_req in scheduler_output.scheduled_cached_reqs: cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if self.is_producer: if self.is_producer:
num_scheduled_tokens = ( num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[cached_req.req_id] scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_tokens = (num_scheduled_tokens + num_computed_tokens)
cached_req.num_computed_tokens) assert req_id in self.chunked_prefill
assert cached_req.req_id in self.chunked_prefill block_ids = new_block_ids[0]
block_ids = cached_req.new_block_ids[0] if not resumed_from_preemption:
if not cached_req.resumed_from_preemption: block_ids = (self.chunked_prefill[req_id][0] + block_ids)
block_ids = (self.chunked_prefill[cached_req.req_id][0] + prompt_token_ids = self.chunked_prefill[req_id][1]
block_ids)
prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
# the request's prompt is chunked prefill again # the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids): if num_tokens < len(prompt_token_ids):
self.chunked_prefill[cached_req.req_id] = ( self.chunked_prefill[req_id] = (block_ids,
block_ids, prompt_token_ids) prompt_token_ids)
continue continue
# the request's prompt is all prefilled finally # the request's prompt is all prefilled finally
meta.add_request(request_id=cached_req.req_id, meta.add_request(request_id=req_id,
token_ids=prompt_token_ids, token_ids=prompt_token_ids,
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size) block_size=self._block_size)
self.chunked_prefill.pop(cached_req.req_id, None) self.chunked_prefill.pop(req_id, None)
continue continue
# NOTE(rob): here we rely on the resumed requests being # NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs. # the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption: if not resumed_from_preemption:
break break
if cached_req.req_id in self._requests_need_load: if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(cached_req.req_id) request, _ = self._requests_need_load.pop(req_id)
total_tokens = cached_req.num_computed_tokens + 1 total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens] token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all # NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request. # of the block_ids for the request.
block_ids = cached_req.new_block_ids[0] block_ids = new_block_ids[0]
meta.add_request(request_id=cached_req.req_id, meta.add_request(request_id=req_id,
token_ids=token_ids, token_ids=token_ids,
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size) block_size=self._block_size)
......
...@@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1):
block_size=self._block_size, block_size=self._block_size,
is_store=True) is_store=True)
for cached_req in scheduler_output.scheduled_cached_reqs: cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_token_ids = cached_reqs.new_token_ids[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
# NOTE(rob): here we rely on the resumed requests being # NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs. # the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption: if not resumed_from_preemption:
break break
if cached_req.req_id in self._requests_need_load: if req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full # NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it # list of token ids (only new tokens). So we look it
# up in the actual request object. # up in the actual request object.
request = self._requests_need_load[cached_req.req_id] request = self._requests_need_load[req_id]
total_tokens = (len(cached_req.new_token_ids) + total_tokens = (len(new_token_ids) + num_computed_tokens)
cached_req.num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens] token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all # NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request. # of the block_ids for the request.
block_ids = cached_req.new_block_ids[0] block_ids = new_block_ids[0]
meta.add_request(token_ids=token_ids, meta.add_request(token_ids=token_ids,
block_ids=block_ids, block_ids=block_ids,
......
...@@ -83,29 +83,27 @@ class NewRequestData: ...@@ -83,29 +83,27 @@ class NewRequestData:
@dataclass @dataclass
class CachedRequestData: class CachedRequestData:
req_id: str req_ids: list[str]
# If resumed_from_preemption is False, new_block_ids will be appended to # If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the # the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs. # request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool resumed_from_preemption: list[bool]
new_token_ids: list[int] new_token_ids: list[list[int]]
new_block_ids: tuple[list[int], ...] new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: int num_computed_tokens: list[int]
@property
def num_reqs(self) -> int:
return len(self.req_ids)
@classmethod @classmethod
def from_request( def make_empty(cls) -> CachedRequestData:
cls,
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: tuple[list[int], ...],
) -> CachedRequestData:
return cls( return cls(
req_id=request.request_id, req_ids=[],
resumed_from_preemption=resumed_from_preemption, resumed_from_preemption=[],
new_token_ids=new_token_ids, new_token_ids=[],
new_block_ids=new_block_ids, new_block_ids=[],
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=[],
) )
...@@ -119,7 +117,7 @@ class SchedulerOutput: ...@@ -119,7 +117,7 @@ class SchedulerOutput:
# list of the requests that have been scheduled before. # list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes, # Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost. # we only send the diff to minimize the communication cost.
scheduled_cached_reqs: list[CachedRequestData] scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens # req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request. # Number of tokens scheduled for each request.
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
from __future__ import annotations from __future__ import annotations
import itertools
import time import time
from collections import defaultdict, deque from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union
...@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface): ...@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
# KV Connector: requests in process of async KV loading or recving # KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set() self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> deque of CachedRequestData
self._cached_reqs_data: dict[
str, deque[CachedRequestData]] = defaultdict(deque)
# Encoder-related. # Encoder-related.
# Calculate encoder cache size if applicable # Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space. # NOTE: For now we use the same budget for both compute and space.
...@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface): ...@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
req_to_new_block_ids[req.request_id]) req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs for req in scheduled_new_reqs
] ]
resumed_reqs_data = [ cached_reqs_data = self._make_cached_request_data(
self._make_cached_request_data( scheduled_running_reqs,
req, scheduled_resumed_reqs,
num_scheduled_tokens[req.request_id], num_scheduled_tokens,
len(scheduled_spec_decode_tokens.get(req.request_id, ())), scheduled_spec_decode_tokens,
req_to_new_block_ids[req.request_id], req_to_new_block_ids,
resumed_from_preemption=True, )
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data, scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
...@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface): ...@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
def _make_cached_request_data( def _make_cached_request_data(
self, self,
request: Request, running_reqs: list[Request],
num_scheduled_tokens: int, resumed_reqs: list[Request],
num_scheduled_spec_tokens: int, num_scheduled_tokens: dict[str, int],
new_block_ids: tuple[list[int], ...], spec_decode_tokens: dict[str, list[int]],
resumed_from_preemption: bool, req_to_new_block_ids: dict[str, tuple[list[int], ...]],
) -> CachedRequestData: ) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating req_ids: list[str] = []
# them at each scheduling step. new_token_ids: list[list[int]] = []
num_computed_tokens = request.num_computed_tokens new_block_ids: list[tuple[list[int], ...]] = []
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens num_computed_tokens: list[int] = []
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens] for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
req_data_queue = self._cached_reqs_data.get(request.request_id) req_ids.append(req_id)
if req_data_queue: num_tokens = (num_scheduled_tokens[req_id] -
req_data = req_data_queue.popleft() len(spec_decode_tokens.get(req_id, ())))
req_data.resumed_from_preemption = resumed_from_preemption token_ids = req.all_token_ids[req.num_computed_tokens:req.
req_data.new_token_ids = new_token_ids num_computed_tokens + num_tokens]
req_data.new_block_ids = new_block_ids new_token_ids.append(token_ids)
req_data.num_computed_tokens = num_computed_tokens new_block_ids.append(req_to_new_block_ids[req_id])
else: num_computed_tokens.append(req.num_computed_tokens)
# No cached request data, or all cached request data has been # Because resumed_reqs is usually empty, it is more efficient to do
# used by the scheduled requests. # in-place appending so that we don't need to allocate a new list.
req_data = CachedRequestData.from_request(request, resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption, resumed_from_preemption += [True] * len(resumed_reqs)
new_token_ids,
new_block_ids) return CachedRequestData(
return req_data req_ids=req_ids,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)
def _try_schedule_encoder_inputs( def _try_schedule_encoder_inputs(
self, self,
...@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface): ...@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
if not stopped: if not stopped:
new_running.append(request) new_running.append(request)
self.running = new_running
# KV Connector: update state for finished KV Transfers. # KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output) self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
# to _cached_reqs_data will cause a memory leak.
if req_data.req_id not in self.finished_req_ids:
self._cached_reqs_data[req_data.req_id].append(req_data)
self.running = new_running
# Create EngineCoreOutputs for all clients that have requests with # Create EngineCoreOutputs for all clients that have requests with
# outputs in this step. # outputs in this step.
engine_core_outputs = { engine_core_outputs = {
...@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface): ...@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
self._free_request(request) self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]: def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished() assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request) delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
request_id = request.request_id request_id = request.request_id
self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id) self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None: if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id) self.finished_req_ids_dict[request.client_index].add(request_id)
...@@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface): ...@@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request): def _free_blocks(self, request: Request):
assert request.is_finished() assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request) self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request) self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id] del self.requests[request.request_id]
......
...@@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id) req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs: req_data = scheduler_output.scheduled_cached_reqs
req_id = req_data.req_id for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_token_ids = req_data.new_token_ids[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states. # Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens. # This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens + num_new_tokens = (num_computed_tokens + len(new_token_ids) -
len(req_data.new_token_ids) -
req_state.num_tokens) req_state.num_tokens)
if num_new_tokens == 1: if num_new_tokens == 1:
# Avoid slicing list in most common case. # Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1]) req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0: elif num_new_tokens > 0:
req_state.output_token_ids.extend( req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:]) new_token_ids[-num_new_tokens:])
# Update the block IDs. # Update the block IDs.
if not req_data.resumed_from_preemption: if not resumed_from_preemption:
# Append the new blocks to the existing block IDs. # Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip(req_state.block_ids, for block_ids, new_ids in zip(req_state.block_ids,
req_data.new_block_ids): new_block_ids):
block_ids.extend(new_block_ids) block_ids.extend(new_ids)
else: else:
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id) req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None: if req_index is None:
...@@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids, self.input_batch.block_table.append_row(new_block_ids, req_index)
req_index)
# Add new_token_ids to token_ids_cpu. # Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids) end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[ self.input_batch.token_ids_cpu[
req_index, req_index, start_token_index:end_token_index] = new_token_ids
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
......
...@@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id) req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs: req_data = scheduler_output.scheduled_cached_reqs
req_id = req_data.req_id for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id] req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states. # Update the cached states.
req_state.num_computed_tokens = req_data.num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
if not req_data.resumed_from_preemption: if not resumed_from_preemption:
# Append the new blocks to the existing block IDs. # Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip(req_state.block_ids, for block_ids, new_ids in zip(req_state.block_ids,
req_data.new_block_ids): new_block_ids):
block_ids.extend(new_block_ids) block_ids.extend(new_ids)
else: else:
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id) req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None: if req_index is None:
...@@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch. # Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = ( self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens) num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids, self.input_batch.block_table.append_row(new_block_ids, req_index)
req_index)
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment