Unverified commit 2a4dbe24, authored by Or Ozeri, committed by GitHub
Browse files

[BugFix] Wait for compute before offloading KV to CPU (#31341)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 8020a604
...@@ -172,6 +172,7 @@ class RequestRunner: ...@@ -172,6 +172,7 @@ class RequestRunner:
self.pending_loads_count: int = 0 self.pending_loads_count: int = 0
self.pending_stores_count: int = 0 self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0
self.completed_loads: list[TransferSummary] = [] self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = [] self.completed_stores: list[TransferSummary] = []
...@@ -279,7 +280,9 @@ class RequestRunner: ...@@ -279,7 +280,9 @@ class RequestRunner:
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
self.pending_stores_count += self.unsubmitted_stores_count
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
self.worker_connector.bind_connector_metadata(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx) self.worker_connector.start_load_kv(self._dummy_ctx)
...@@ -414,10 +417,13 @@ def test_offloading_connector(request_runner): ...@@ -414,10 +417,13 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = ( runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2]) lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
) )
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5)) runner.run(decoded_tokens=[0])
# add block missing 1 token -> no offload # add block missing 1 token -> no offload
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1)) runner.run(
decoded_tokens=[0] * (offloaded_block_size - 1),
expected_stored_gpu_block_indexes=(3, 4, 5),
)
runner.manager.prepare_store.assert_not_called() runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store # +1 token -> single block, fail prepare_store
...@@ -435,23 +441,20 @@ def test_offloading_connector(request_runner): ...@@ -435,23 +441,20 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = ( runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes) lambda block_hashes: generate_store_output(block_hashes)
) )
runner.run( runner.run(decoded_tokens=[0] * offloaded_block_size)
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.manager.touch.assert_called() runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0]) block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6 assert len(block_hashes1) == 6
# terminate request # terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(15, 16, 17),
)
# create a new request differing only on the last token # create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1]) runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run( runner.run(decoded_tokens=[0])
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.manager.touch.assert_called() runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0]) block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6 assert len(block_hashes2) == 6
...@@ -461,7 +464,10 @@ def test_offloading_connector(request_runner): ...@@ -461,7 +464,10 @@ def test_offloading_connector(request_runner):
assert block_hashes1[5] != block_hashes2[5] assert block_hashes1[5] != block_hashes2[5]
# terminate request # terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
# full_block_tokens - num_computed_tokens < offloaded_block_size # full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request( runner.new_request(
......
...@@ -78,7 +78,7 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -78,7 +78,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata) self.connector_worker.start_kv_transfers(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
pass pass
...@@ -95,7 +95,7 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -95,7 +95,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def wait_for_save(self): def wait_for_save(self):
assert self.connector_worker is not None assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_store_kv(self._connector_metadata) self.connector_worker.prepare_store_kv(self._connector_metadata)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
assert self.connector_worker is not None assert self.connector_worker is not None
...@@ -427,6 +427,8 @@ class OffloadingConnectorWorker: ...@@ -427,6 +427,8 @@ class OffloadingConnectorWorker:
self._load_job: dict[ReqId, int] = {} self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs) # req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set) self._store_jobs = defaultdict[ReqId, set[int]](set)
# list of store jobs pending submission (job_id, transfer_spec)
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
self._finished_reqs_waiting_for_store: set[ReqId] = set() self._finished_reqs_waiting_for_store: set[ReqId] = set()
...@@ -464,20 +466,29 @@ class OffloadingConnectorWorker: ...@@ -464,20 +466,29 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend} attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends) self._register_handlers(kv_caches, attn_backends)
def start_load_kv(self, metadata: OffloadingConnectorMetadata): def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id, transfer_spec in metadata.reqs_to_load.items(): for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id() job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False) self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job assert req_id not in self._load_job
self._load_job[req_id] = job_id self._load_job[req_id] = job_id
assert self.worker.transfer_async(job_id, transfer_spec) success = self.worker.transfer_async(job_id, transfer_spec)
assert success
def start_store_kv(self, metadata: OffloadingConnectorMetadata): def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items(): for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id() job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True) self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id) self._store_jobs[req_id].add(job_id)
assert self.worker.transfer_async(job_id, transfer_spec) # NOTE(orozery): defer the store to the beginning of the next engine step,
# so that offloading starts AFTER transfers related to token sampling,
# thereby avoiding delays to token generation due to offloading.
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
""" """
......
...@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
kv_dim_before_num_blocks: list[bool], kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int, src_block_size_factor: int,
dst_block_size_factor: int, dst_block_size_factor: int,
priority: int,
): ):
""" """
Initialize a SingleDirectionOffloadingHandler. Initialize a SingleDirectionOffloadingHandler.
...@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
per KV block in a source tensor. per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor. per KV block in a destination tensor.
priority: The priority of the backing CUDA streams.
Lower numbers indicate higher priority.
""" """
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks) assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
...@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
self.src_block_size_factor: int = src_block_size_factor self.src_block_size_factor: int = src_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor self.dst_block_size_factor: int = dst_block_size_factor
self.priority = priority
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
# queue of transfers (job_id, stream, event) # queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque() self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
...@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1]) expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst) src_to_dst_tensor = torch.from_numpy(src_to_dst)
stream = ( stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
self._stream_pool.pop()
if self._stream_pool
else torch.cuda.Stream(priority=self.priority)
)
event = self._event_pool.pop() if self._event_pool else torch.Event() event = self._event_pool.pop() if self._event_pool else torch.Event()
if self.gpu_to_cpu:
# wait for model computation to finish before offloading
stream.wait_stream(torch.cuda.current_stream())
if self._transfers: if self._transfers:
_, _, last_event = self._transfers[-1] _, _, last_event = self._transfers[-1]
# assure job will start only after the previous one completes # assure job will start only after the previous one completes
...@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers: ...@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks, kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor, src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor, dst_block_size_factor=cpu_block_size_factor,
priority=1,
) )
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler( self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
...@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers: ...@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks, kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor, src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor, dst_block_size_factor=gpu_block_size_factor,
priority=-1,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment