Unverified Commit c34ba6b9 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize compute maxsim using batched version, 3.2% E2E throughput improvement (#36710)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 24062b70
...@@ -4,13 +4,10 @@ ...@@ -4,13 +4,10 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
from vllm.entrypoints.pooling.score.utils import ( from vllm.entrypoints.pooling.score.utils import (
compute_maxsim_score,
compute_maxsim_scores,
get_score_prompt, get_score_prompt,
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
...@@ -354,36 +351,3 @@ class TestGetScorePrompt: ...@@ -354,36 +351,3 @@ class TestGetScorePrompt:
assert_prompt_tokenization_consistent( assert_prompt_tokenization_consistent(
cross_encoder_tokenizer, full_prompt, engine_prompt cross_encoder_tokenizer, full_prompt, engine_prompt
) )
def test_compute_maxsim_scores_matches_reference_per_pair() -> None:
generator = torch.Generator()
generator.manual_seed(7)
shared_query = torch.randn(5, 8, generator=generator)
q_embs = [
shared_query, # 1:N style shared query
shared_query,
torch.randn(2, 8, generator=generator),
torch.randn(4, 8, generator=generator),
]
d_embs = [
torch.randn(6, 8, generator=generator),
torch.randn(3, 8, generator=generator),
torch.randn(5, 8, generator=generator),
torch.randn(7, 8, generator=generator),
]
batched_scores = compute_maxsim_scores(
q_embs,
d_embs,
max_batch_size=4,
max_score_matrix_elements=40, # batch shrinking path.
)
reference_scores = [
compute_maxsim_score(q, d).to("cpu") for q, d in zip(q_embs, d_embs)
]
assert len(batched_scores) == len(reference_scores)
for batched, reference in zip(batched_scores, reference_scores):
torch.testing.assert_close(batched, reference, rtol=1e-4, atol=1e-4)
...@@ -64,6 +64,47 @@ def test_postprocess_scores_and_releases_query_cache(): ...@@ -64,6 +64,47 @@ def test_postprocess_scores_and_releases_query_cache():
) )
def test_postprocess_scores_docs_in_batch():
runner = LateInteractionRunner()
query_key = "query-batch"
query_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
doc_emb_1 = torch.tensor([[1.0, 0.0], [0.5, 0.5]], dtype=torch.float32)
doc_emb_2 = torch.tensor([[0.0, 1.0], [0.3, 0.7], [1.0, 0.0]], dtype=torch.float32)
query_params = _make_pooling_params(
build_late_interaction_query_params(query_key=query_key, query_uses=2)
)
runner.postprocess_pooler_output(
raw_pooler_output=[query_emb],
pooling_params=[query_params],
req_ids=["query-req"],
finished_mask=[True],
)
doc_params = _make_pooling_params(
build_late_interaction_doc_params(query_key=query_key)
)
doc_output = runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb_1, doc_emb_2],
pooling_params=[doc_params, doc_params],
req_ids=["doc-req-1", "doc-req-2"],
finished_mask=[True, True],
)
assert isinstance(doc_output, list)
assert doc_output[0] is not None
assert doc_output[1] is not None
assert torch.allclose(doc_output[0], compute_maxsim_score(query_emb, doc_emb_1))
assert torch.allclose(doc_output[1], compute_maxsim_score(query_emb, doc_emb_2))
with pytest.raises(ValueError, match="query cache miss"):
runner.postprocess_pooler_output(
raw_pooler_output=[doc_emb_1],
pooling_params=[doc_params],
req_ids=["doc-req-3"],
finished_mask=[True],
)
def test_finished_request_releases_unscored_doc_use(): def test_finished_request_releases_unscored_doc_use():
runner = LateInteractionRunner() runner = LateInteractionRunner()
query_key = "query-cancel" query_key = "query-cancel"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence from collections.abc import Iterable
from typing import Any, TypeAlias, cast from typing import Any, TypeAlias, cast
import torch import torch
...@@ -25,7 +25,6 @@ from vllm.inputs.data import PromptType, TextPrompt ...@@ -25,7 +25,6 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -54,91 +53,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens ...@@ -54,91 +53,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum() return token_scores.amax(dim=-1).sum()
def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
return use_gpu_for_pooling_score and not current_platform.is_cpu()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000,
use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device(
current_platform.device_type
if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
else "cpu"
)
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
dim = int(batch_q[0].shape[1])
dtype = batch_q[0].dtype
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=dtype, device=compute_device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=dtype, device=compute_device
)
q_mask = torch.zeros(
(batch_size, max_q), dtype=torch.bool, device=compute_device
)
d_mask = torch.zeros(
(batch_size, max_d), dtype=torch.bool, device=compute_device
)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0)
batch_scores = max_per_query.sum(dim=-1).to("cpu")
scores.extend(batch_scores.unbind(0))
start = end
return [cast(torch.Tensor, score) for score in scores]
class ScoreMultiModalParam(TypedDict, total=False): class ScoreMultiModalParam(TypedDict, total=False):
""" """
A specialized parameter type for scoring multimodal content A specialized parameter type for scoring multimodal content
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import zlib import zlib
from collections.abc import Sequence
import torch import torch
...@@ -62,3 +63,81 @@ def compute_maxsim_score( ...@@ -62,3 +63,81 @@ def compute_maxsim_score(
# compute in float32 for numerical stability # compute in float32 for numerical stability
token_scores = torch.matmul(q_emb.float(), d_emb.float().T) token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
return token_scores.amax(dim=-1).sum() return token_scores.amax(dim=-1).sum()
def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor],
max_batch_size: int = 64,
max_score_matrix_elements: int = 64_000_000,
) -> list[torch.Tensor]:
"""Compute MaxSim for multiple query/doc pairs in mini-batches."""
if len(q_embs) != len(d_embs):
raise ValueError("q_embs and d_embs must have the same length")
num_pairs = len(q_embs)
if num_pairs == 0:
return []
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if max_score_matrix_elements <= 0:
raise ValueError("max_score_matrix_elements must be greater than 0")
for q_emb, d_emb in zip(q_embs, d_embs):
if q_emb.ndim != 2 or d_emb.ndim != 2:
raise ValueError("Each embedding tensor must be 2-D")
if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim")
if q_emb.device != d_emb.device:
raise ValueError("Query and document embeddings must be on same device")
scores: list[torch.Tensor] = []
start = 0
while start < num_pairs:
end = min(start + max_batch_size, num_pairs)
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
# keep score matrix bounded to avoid oversized allocations.
while (
end - start > 1
and (end - start) * max_q * max_d > max_score_matrix_elements
):
end -= 1
max_q = max(int(x.shape[0]) for x in q_embs[start:end])
max_d = max(int(x.shape[0]) for x in d_embs[start:end])
batch_q = q_embs[start:end]
batch_d = d_embs[start:end]
batch_size = end - start
device = batch_q[0].device
dim = int(batch_q[0].shape[1])
q_batch = torch.zeros(
(batch_size, max_q, dim), dtype=torch.float32, device=device
)
d_batch = torch.zeros(
(batch_size, max_d, dim), dtype=torch.float32, device=device
)
q_mask = torch.zeros((batch_size, max_q), dtype=torch.bool, device=device)
d_mask = torch.zeros((batch_size, max_d), dtype=torch.bool, device=device)
# copy to padded tensors
for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
q_len = int(q_emb.shape[0])
d_len = int(d_emb.shape[0])
q_batch[i, :q_len] = q_emb.to(device=device, dtype=torch.float32)
d_batch[i, :d_len] = d_emb.to(device=device, dtype=torch.float32)
q_mask[i, :q_len] = True
d_mask[i, :d_len] = True
token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
max_per_query = token_scores.amax(dim=-1)
max_per_query.masked_fill_(~q_mask, 0.0)
batch_scores = max_per_query.sum(dim=-1)
scores.extend(batch_scores.unbind(0))
start = end
return scores
...@@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput ...@@ -9,7 +9,7 @@ from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.late_interaction import ( from vllm.v1.pool.late_interaction import (
LATE_INTERACTION_MODE_CACHE_QUERY, LATE_INTERACTION_MODE_CACHE_QUERY,
LATE_INTERACTION_MODE_SCORE_DOC, LATE_INTERACTION_MODE_SCORE_DOC,
compute_maxsim_score, compute_maxsim_scores,
) )
...@@ -72,6 +72,11 @@ class LateInteractionRunner: ...@@ -72,6 +72,11 @@ class LateInteractionRunner:
return raw_pooler_output return raw_pooler_output
outputs: list[torch.Tensor | None] = list(raw_pooler_output) outputs: list[torch.Tensor | None] = list(raw_pooler_output)
score_indices: list[int] = []
score_req_ids: list[str] = []
score_query_keys: list[str] = []
score_queries: list[torch.Tensor] = []
score_docs: list[torch.Tensor] = []
for i, (req_id, output, params, finished) in enumerate( for i, (req_id, output, params, finished) in enumerate(
zip(req_ids, outputs, pooling_params, finished_mask) zip(req_ids, outputs, pooling_params, finished_mask)
): ):
...@@ -101,13 +106,24 @@ class LateInteractionRunner: ...@@ -101,13 +106,24 @@ class LateInteractionRunner:
"before their paired document requests." "before their paired document requests."
) )
outputs[i] = compute_maxsim_score(query_output, output) score_indices.append(i)
self._doc_query_keys.pop(req_id, None) score_req_ids.append(req_id)
self._release_query_use(query_key) score_query_keys.append(query_key)
score_queries.append(query_output)
score_docs.append(output)
continue continue
raise ValueError(f"Unsupported late-interaction mode: {mode!r}") raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
if score_indices:
score_values = compute_maxsim_scores(score_queries, score_docs)
for i, req_id, query_key, score in zip(
score_indices, score_req_ids, score_query_keys, score_values
):
outputs[i] = score
self._doc_query_keys.pop(req_id, None)
self._release_query_use(query_key)
return outputs return outputs
def _release_query_use(self, query_key: str) -> None: def _release_query_use(self, query_key: str) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment