Commit fbe8587a authored by laibao's avatar laibao
Browse files

perf: 加速 v1 InputBatch.add_request 的 token_ids 拷贝

新增环境变量开关 VLLM_V1_FAST_TOKEN_ID_COPY(默认关闭)。开启后在 CachedRequestState 中缓存 prompt_token_ids 对应的 np.int32 数组,并在 add_request 里用 np.copyto 将其写入 token_ids_cpu,避免长 prompt 场景下反复进行 list->NumPy 转换的开销(尤其是抢占/反复进出 batch 时)。
parent 1e57506d
......@@ -202,6 +202,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1308,6 +1309,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
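# If enabled, the V1 InputBatch caches each request's prompt token ids as an
# np.int32 array and writes them into token_ids_cpu with np.copyto instead of
# re-converting the Python list on every add_request. Disabled by default.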
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -2,12 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, cast
import numpy as np
import torch
from vllm import envs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -44,6 +45,11 @@ class CachedRequestState:
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
# Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
_prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
init=False,
repr=False,
compare=False)
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
......@@ -285,15 +291,50 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
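# OPTIMIZATION (only when VLLM_V1_FAST_TOKEN_ID_COPY is enabled): copy from
# a cached np.int32 array with np.copyto instead of slice-assigning the
# Python list. This avoids the ~550 µs overhead of converting the prompt
# list to a NumPy array each time; output token ids are still converted on
# every call.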
num_prompt_tokens = len(request.prompt_token_ids)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
else:
prompt_token_ids_np = getattr(request, "_prompt_token_ids_np", None)
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
try:
rebuild_prompt_cache = (prompt_token_ids_np.dtype != np.int32
or prompt_token_ids_np.size !=
num_prompt_tokens)
except Exception:
rebuild_prompt_cache = True
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids,
dtype=np.int32)
try:
request._prompt_token_ids_np = prompt_token_ids_np
except Exception:
pass
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting='no',
)
start_idx = num_prompt_tokens
output_token_ids_np = np.asarray(request.output_token_ids,
dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting='no',
)
num_spec_tokens = 0
if request.spec_token_ids != None:
num_spec_tokens = len(request.spec_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment