"vscode:/vscode.git/clone" did not exist on "541c1852d37b9502fbc06253def70e901ca0c352"
Unverified commit 66f927f2, authored by Andreas Karatzas, committed by GitHub
Browse files

[Bugfix] Fix pooling non-determinism from pinned prompt_lens aliasing (#37775)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent e78bc742
...@@ -378,3 +378,65 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis ...@@ -378,3 +378,65 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
ref_input_batch.refresh_metadata() ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch) _compare_objs(input_batch, ref_input_batch)
def _construct_pooling_request(req_id_suffix: int):
    """Build a fresh ``CachedRequestState`` for a pooling ("classify") request.

    The prompt consists of random token ids drawn from ``[0, VOCAB_SIZE)``;
    its length is drawn from ``[10, MAX_PROMPT_SIZE)``. The request has no
    sampling params, no multimodal features, no generator and no output
    tokens.

    Args:
        req_id_suffix: Integer appended to ``"pool_req_"`` to form the
            request id.

    Returns:
        The newly constructed ``CachedRequestState``.
    """
    from vllm.pooling_params import PoolingParams

    # Draw the length first and then one id per position, so the global
    # NumPy RNG is consumed in a fixed, predictable order.
    prompt_len = np.random.randint(10, MAX_PROMPT_SIZE)
    token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range(prompt_len)]
    classify_params = PoolingParams(task="classify")

    return CachedRequestState(
        req_id=f"pool_req_{req_id_suffix}",
        prompt_token_ids=token_ids,
        sampling_params=None,
        pooling_params=classify_params,
        mm_features=[],
        block_ids=([],),
        generator=None,
        num_computed_tokens=0,
        output_token_ids=[],
    )
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_pooling_prompt_lens_not_aliased(device: str):
    """Verify that prompt_lens in PoolingMetadata does not share memory
    with the internal num_prompt_tokens pinned buffer. Guards against possible
    non-determinism in pooling metadata due to mutations to the internal buffer.

    Args:
        device: Device string the ``InputBatch`` is created on (parametrized
            over ``CUDA_DEVICES``).
    """
    batch_size = 4
    input_batch = InputBatch(
        max_num_reqs=batch_size * 2,
        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        max_num_batched_tokens=batch_size * (MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS),
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    # Populate the batch with pooling requests.
    for i in range(batch_size):
        input_batch.add_request(_construct_pooling_request(i))
    input_batch.refresh_metadata()

    # Take an independent copy of prompt_lens before touching the buffer, so
    # there is a reference value that cannot be affected by the mutation.
    metadata = input_batch.get_pooling_metadata()
    prompt_lens_snapshot = metadata.prompt_lens.clone()

    # Mutate the internal buffer (simulates next batch adding new requests)
    input_batch.num_prompt_tokens_cpu_tensor.fill_(999)

    # prompt_lens must be unaffected by the mutation
    assert torch.equal(metadata.prompt_lens, prompt_lens_snapshot), (
        "prompt_lens shares memory with internal pinned buffer; "
        "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
        f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
    )
...@@ -892,7 +892,7 @@ class InputBatch: ...@@ -892,7 +892,7 @@ class InputBatch:
pooling_states = self.get_pooling_states() pooling_states = self.get_pooling_states()
return PoolingMetadata( return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
prompt_token_ids=self.sampling_metadata.prompt_token_ids, prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params, pooling_params=pooling_params,
pooling_states=pooling_states, pooling_states=pooling_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment