Unverified commit c21d0039, authored by Wentao Ye and committed by GitHub
Browse files

[Refactor] Fix maxsim cuda platform and add cli to control it (#35427)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 7d8bbe6f
...@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers: int = args.api_server_count num_api_servers: int = args.api_server_count
assert num_api_servers > 0 assert num_api_servers > 0
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
# TODO(wentao): remove this once well tested
raise ValueError(
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
)
if num_api_servers > 1: if num_api_servers > 1:
setup_multiprocess_prometheus() setup_multiprocess_prometheus()
......
...@@ -278,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs): ...@@ -278,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments. Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM. Uses vendored static assets bundled with vLLM.
""" """
use_gpu_for_pooling_score: bool = False
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance.
https://github.com/vllm-project/vllm/pull/35330"""
@classmethod @classmethod
def _customize_cli_kwargs( def _customize_cli_kwargs(
......
...@@ -115,6 +115,7 @@ def init_pooling_state( ...@@ -115,6 +115,7 @@ def init_pooling_state(
request_logger=request_logger, request_logger=request_logger,
score_template=resolved_chat_template, score_template=resolved_chat_template,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
) )
if any(t in supported_tasks for t in ("embed", "score", "token_embed")) if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None else None
......
...@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing): ...@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
score_template: str | None = None, score_template: str | None = None,
log_error_stack: bool = False, log_error_stack: bool = False,
use_gpu_for_pooling_score: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
...@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing): ...@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
) )
self.score_template = score_template self.score_template = score_template
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
...@@ -314,6 +316,7 @@ class ServingScores(OpenAIServing): ...@@ -314,6 +316,7 @@ class ServingScores(OpenAIServing):
maxsim_scores = compute_maxsim_scores( maxsim_scores = compute_maxsim_scores(
[emb.outputs.data for emb in emb_data_1], [emb.outputs.data for emb in emb_data_1],
[emb.outputs.data for emb in emb_data_2], [emb.outputs.data for emb in emb_data_2],
use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
) )
scores: list[PoolingRequestOutput] = [] scores: list[PoolingRequestOutput] = []
......
...@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt ...@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -53,11 +54,16 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens ...@@ -53,11 +54,16 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum() return token_scores.amax(dim=-1).sum()
def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
return use_gpu_for_pooling_score and not current_platform.is_cpu()
def compute_maxsim_scores( def compute_maxsim_scores(
q_embs: Sequence[torch.Tensor], q_embs: Sequence[torch.Tensor],
d_embs: Sequence[torch.Tensor], d_embs: Sequence[torch.Tensor],
max_batch_size: int = 16, max_batch_size: int = 16,
max_score_matrix_elements: int = 16_000_000, max_score_matrix_elements: int = 16_000_000,
use_gpu_for_pooling_score: bool = False,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
"""Compute ColBERT MaxSim scores in padded mini-batches.""" """Compute ColBERT MaxSim scores in padded mini-batches."""
if len(q_embs) != len(d_embs): if len(q_embs) != len(d_embs):
...@@ -73,7 +79,11 @@ def compute_maxsim_scores( ...@@ -73,7 +79,11 @@ def compute_maxsim_scores(
if q_emb.shape[1] != d_emb.shape[1]: if q_emb.shape[1] != d_emb.shape[1]:
raise ValueError("Query and document embeddings must have same dim") raise ValueError("Query and document embeddings must have same dim")
compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") compute_device = torch.device(
current_platform.device_type
if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
else "cpu"
)
scores: list[torch.Tensor] = [] scores: list[torch.Tensor] = []
start = 0 start = 0
while start < num_pairs: while start < num_pairs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment