"vllm/vscode:/vscode.git/clone" did not exist on "22cf679aadca99311cfb5a9f894039e464e366aa"
Unverified commit 00910171, authored by Philip Ottesen, committed by GitHub
Browse files

fix(worker): optimize swap_states to copy only active token prefixes (#34733)


Signed-off-by: Philip Ottesen <phiott256@gmail.com>
parent 0d81a1fe
......@@ -529,6 +529,12 @@ class InputBatch:
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
# Only swap the active token prefix for each request. Copying full
# max_model_len rows is expensive and unnecessary during reordering.
i1_active_token_count = self._get_active_token_count(i1)
i2_active_token_count = self._get_active_token_count(i2)
max_active_token_count = max(i1_active_token_count, i2_active_token_count)
self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] = (
self.req_output_token_ids[i2],
......@@ -560,12 +566,15 @@ class InputBatch:
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
tmp_token_ids = self.token_ids_cpu[i1, :max_active_token_count].copy()
self.token_ids_cpu[i1, :max_active_token_count] = self.token_ids_cpu[
i2, :max_active_token_count
]
self.token_ids_cpu[i2, :max_active_token_count] = tmp_token_ids
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
self.is_token_ids[[i1, i2], :max_active_token_count] = self.is_token_ids[
[i2, i1], :max_active_token_count
]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
......@@ -629,6 +638,11 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
def _get_active_token_count(self, req_index: int) -> int:
    """Return the number of active tokens stored for one request.

    The count is the request's ``num_tokens_no_spec`` entry plus the
    number of entries currently held in its ``spec_token_ids`` list.

    Args:
        req_index: Index of the request in the per-request state
            containers (``num_tokens_no_spec`` and ``spec_token_ids``).

    Returns:
        The combined token count as a plain Python ``int``.
    """
    # int() normalizes the stored entry so the result is a builtin int
    # (presumably the entry is a NumPy scalar -- confirm against the
    # container's declaration, which is not visible here).
    no_spec_count = int(self.num_tokens_no_spec[req_index])
    spec_count = len(self.spec_token_ids[req_index])
    return no_spec_count + spec_count
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
......@@ -678,9 +692,7 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens_no_spec[last_req_index] + len(
self.spec_token_ids[last_req_index]
)
num_tokens = self._get_active_token_count(last_req_index)
(self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
self.spec_token_ids[empty_index],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment