Unverified commit 995dea13, authored by Wentao Ye, committed by GitHub
Browse files

[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9%...


[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 8c0b6267
...@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis ...@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
_compare_objs(input_batch, ref_input_batch) _compare_objs(input_batch, ref_input_batch)
def _construct_pooling_request(req_id_suffix: int): def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
prompt_token_ids = [ prompt_token_ids = [
...@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int): ...@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int):
req_id=f"pool_req_{req_id_suffix}", req_id=f"pool_req_{req_id_suffix}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=None, sampling_params=None,
pooling_params=PoolingParams(task="classify"), pooling_params=pooling_params or PoolingParams(task="classify"),
mm_features=[], mm_features=[],
block_ids=([],), block_ids=([],),
generator=None, generator=None,
...@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str): ...@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str):
"mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. " "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}" f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
) )
@pytest.mark.parametrize(
    ("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"),
    [
        ({"task": "classify"}, False, False),
        ({"task": "classify", "requires_token_ids": True}, True, True),
    ],
)
def test_pooling_metadata_token_id_buffers(
    pooling_params: dict[str, object],
    expect_device_prompt_token_ids: bool,
    expect_cpu_prompt_token_ids: bool,
):
    """Check which prompt-token-id buffers a pooling request populates.

    A single pooling request built from ``pooling_params`` is added to a
    CPU ``InputBatch``. Depending on the two ``expect_*`` flags, the
    device-side buffer (on both the sampling metadata and the pooling
    metadata) and the CPU-side buffer must either be ``None`` or hold the
    request's prompt token IDs.
    """
    from vllm.pooling_params import PoolingParams

    # The same bound is used for the model length and the batched-token budget.
    token_budget = MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS
    batch = InputBatch(
        max_num_reqs=1,
        max_model_len=token_budget,
        max_num_batched_tokens=token_budget,
        device=torch.device("cpu"),
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    request = _construct_pooling_request(0, PoolingParams(**pooling_params))
    batch.add_request(request)
    batch.refresh_metadata()
    pooling_metadata = batch.get_pooling_metadata()

    # Device-side buffer: visible through both metadata objects.
    if not expect_device_prompt_token_ids:
        assert batch.sampling_metadata.prompt_token_ids is None
        assert pooling_metadata.prompt_token_ids is None
    else:
        assert batch.sampling_metadata.prompt_token_ids is not None
        assert pooling_metadata.prompt_token_ids is not None
        device_rows = pooling_metadata.get_prompt_token_ids()
        assert device_rows[0].tolist() == request.prompt_token_ids

    # CPU-side buffer: exposed only on the pooling metadata.
    if not expect_cpu_prompt_token_ids:
        assert pooling_metadata.prompt_token_ids_cpu is None
    else:
        assert pooling_metadata.prompt_token_ids_cpu is not None
        cpu_rows = pooling_metadata.get_prompt_token_ids_cpu()
        assert cpu_rows[0].tolist() == request.prompt_token_ids
...@@ -18,7 +18,7 @@ ActivationFn = Callable[[_T], _T] ...@@ -18,7 +18,7 @@ ActivationFn = Callable[[_T], _T]
@dataclass(frozen=True) @dataclass(frozen=True)
class PoolingParamsUpdate: class PoolingParamsUpdate:
requires_token_ids: bool = False requires_token_ids: bool = False
"""Set this flag to enable `get_prompt_token_ids` for your pooler.""" """Set this flag to enable prompt token IDs for your pooler."""
def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate": def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
return PoolingParamsUpdate( return PoolingParamsUpdate(
......
...@@ -146,17 +146,19 @@ class BOSEOSFilter(Pooler): ...@@ -146,17 +146,19 @@ class BOSEOSFilter(Pooler):
) -> PoolerOutput: ) -> PoolerOutput:
pooled_outputs = self.pooler(hidden_states, pooling_metadata) pooled_outputs = self.pooler(hidden_states, pooling_metadata)
assert isinstance(pooled_outputs, list) assert isinstance(pooled_outputs, list)
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
for i, prompt_len in enumerate(pooling_metadata.prompt_lens): for i, (prompt_len, token_ids) in enumerate(
zip(pooling_metadata.prompt_lens, prompt_token_ids)
):
pooled_data = pooled_outputs[i] pooled_data = pooled_outputs[i]
assert ( assert (
isinstance(pooled_data, torch.Tensor) isinstance(pooled_data, torch.Tensor)
and pooled_data.shape[0] == prompt_len and pooled_data.shape[0] == prompt_len
) )
token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len] if int(token_ids[0]) == self.bos_token_id:
if token_ids[0] == self.bos_token_id:
pooled_data = pooled_data[1:] pooled_data = pooled_data[1:]
if token_ids[-1] == self.eos_token_id: if int(token_ids[-1]) == self.eos_token_id:
pooled_data = pooled_data[:-1] pooled_data = pooled_data[:-1]
pooled_outputs[i] = pooled_data.squeeze(-1) pooled_outputs[i] = pooled_data.squeeze(-1)
......
...@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler): ...@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler):
lens: list[int] = lens_tensor.tolist() lens: list[int] = lens_tensor.tolist()
B: int = len(lens) B: int = len(lens)
token_ids = pooling_metadata.prompt_token_ids prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
offset = 0 offset = 0
pooled_list: list[torch.Tensor] = [] pooled_list: list[torch.Tensor] = []
for i in range(B): for i in range(B):
L = int(lens[i]) L = int(lens[i])
hs = hidden_states[offset : offset + L] hs = hidden_states[offset : offset + L]
token_ids = prompt_token_ids[i]
start_idx = 0 start_idx = 0
end_idx = L end_idx = L
if self.remove_cls_sep and token_ids is not None: if self.remove_cls_sep:
if ( if (
self.cls_token_id is not None self.cls_token_id is not None
and token_ids[i, 0].item() == self.cls_token_id and int(token_ids[0]) == self.cls_token_id
): ):
start_idx = 1 start_idx = 1
if ( if (
self.sep_token_id is not None self.sep_token_id is not None
and token_ids[i, L - 1].item() == self.sep_token_id and int(token_ids[L - 1]) == self.sep_token_id
): ):
end_idx = max(start_idx, L - 1) end_idx = max(start_idx, L - 1)
......
...@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod): ...@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod):
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput: ) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens prompt_lens = pooling_metadata.prompt_lens
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
instr_lens = torch.tensor( instr_lens = torch.tensor(
[ [
self._get_instruction_len(token_ids.cpu().numpy()) self._get_instruction_len(token_ids.numpy())
for token_ids in pooling_metadata.get_prompt_token_ids() for token_ids in prompt_token_ids
], ],
device="cpu", device="cpu",
) )
......
...@@ -50,7 +50,8 @@ class PoolingMetadata: ...@@ -50,7 +50,8 @@ class PoolingMetadata:
"""Tensors for pooling.""" """Tensors for pooling."""
prompt_lens: torch.Tensor # CPU Tensor prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: torch.Tensor | None prompt_token_ids: torch.Tensor | None # Model-device tensor
prompt_token_ids_cpu: torch.Tensor | None # CPU tensor
pooling_params: list[PoolingParams] pooling_params: list[PoolingParams]
pooling_states: list[PoolingStates] pooling_states: list[PoolingStates]
pooling_cursor: PoolingCursor | None = None pooling_cursor: PoolingCursor | None = None
...@@ -73,6 +74,9 @@ class PoolingMetadata: ...@@ -73,6 +74,9 @@ class PoolingMetadata:
prompt_token_ids=None prompt_token_ids=None
if self.prompt_token_ids is None if self.prompt_token_ids is None
else self.prompt_token_ids[indices], else self.prompt_token_ids[indices],
prompt_token_ids_cpu=None
if self.prompt_token_ids_cpu is None
else self.prompt_token_ids_cpu[indices],
pooling_params=self.pooling_params[indices], pooling_params=self.pooling_params[indices],
pooling_states=self.pooling_states[indices], pooling_states=self.pooling_states[indices],
pooling_cursor=None pooling_cursor=None
...@@ -85,7 +89,13 @@ class PoolingMetadata: ...@@ -85,7 +89,13 @@ class PoolingMetadata:
assert prompt_token_ids is not None, ( assert prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`" "Please set `requires_token_ids=True` in `get_pooling_updates`"
) )
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def get_prompt_token_ids_cpu(self) -> list[torch.Tensor]:
prompt_token_ids = self.prompt_token_ids_cpu
assert prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def get_pooling_cursor(self) -> PoolingCursor: def get_pooling_cursor(self) -> PoolingCursor:
......
...@@ -833,8 +833,13 @@ class InputBatch: ...@@ -833,8 +833,13 @@ class InputBatch:
# step pooling during the sampling/pooling process. # step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which # Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied. # need penalties/step_pooler to be applied.
prompt_token_ids_cpu = (
self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None
)
prompt_token_ids = ( prompt_token_ids = (
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None prompt_token_ids_cpu.to(device=self.device, non_blocking=True)
if prompt_token_ids_cpu is not None
else None
) )
# Only set output_token_ids if required by the current requests' # Only set output_token_ids if required by the current requests'
...@@ -891,15 +896,19 @@ class InputBatch: ...@@ -891,15 +896,19 @@ class InputBatch:
def get_pooling_metadata(self) -> PoolingMetadata: def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params() pooling_params = self.get_pooling_params()
pooling_states = self.get_pooling_states() pooling_states = self.get_pooling_states()
prompt_token_ids_cpu = None
if any(p.requires_token_ids for p in pooling_params):
prompt_token_ids_cpu = self._make_prompt_token_ids_cpu_tensor()
return PoolingMetadata( return PoolingMetadata(
prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(), prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
prompt_token_ids=self.sampling_metadata.prompt_token_ids, prompt_token_ids=self.sampling_metadata.prompt_token_ids,
prompt_token_ids_cpu=prompt_token_ids_cpu,
pooling_params=pooling_params, pooling_params=pooling_params,
pooling_states=pooling_states, pooling_states=pooling_states,
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_cpu_tensor(self) -> torch.Tensor:
num_reqs = self.num_reqs num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max() max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty( prompt_token_ids_cpu_tensor = torch.empty(
...@@ -914,7 +923,7 @@ class InputBatch: ...@@ -914,7 +923,7 @@ class InputBatch:
# token_id of this value. # token_id of this value.
for i in range(num_reqs): for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) return prompt_token_ids_cpu_tensor
def make_lora_inputs( def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
......
...@@ -5653,6 +5653,7 @@ class GPUModelRunner( ...@@ -5653,6 +5653,7 @@ class GPUModelRunner(
dummy_metadata = PoolingMetadata( dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens, prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids, prompt_token_ids=dummy_token_ids,
prompt_token_ids_cpu=dummy_token_ids.cpu(),
pooling_params=[dummy_pooling_params] * num_reqs, pooling_params=[dummy_pooling_params] * num_reqs,
pooling_states=[PoolingStates() for i in range(num_reqs)], pooling_states=[PoolingStates() for i in range(num_reqs)],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment