"docs/source/getting_started/faq.md" did not exist on "cad5c0a6eda057eeece87a42fff49fef3e18a2ac"
Commit 8da572a9 authored by laibao's avatar laibao
Browse files

perf: 加速 v1 InputBatch.add_request 的 token_ids 拷贝

新增环境变量开关 VLLM_V1_FAST_TOKEN_ID_COPY(默认关闭)。开启后在 CachedRequestState 中缓存 prompt_token_ids 的 np.int32,并在 add_request 里用 np.copyto 写入 token_ids_cpu,避免长 prompt 场景反复 list->NumPy 转换开销(尤其是抢占/反复进出 batch 时)
parent 383f2ce8
...@@ -244,7 +244,7 @@ if TYPE_CHECKING: ...@@ -244,7 +244,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1685,6 +1685,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1685,6 +1685,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use fast token id copy
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch # Datastructures defining a GPU input batch
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional, cast from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm import envs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -45,6 +46,11 @@ class CachedRequestState: ...@@ -45,6 +46,11 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None prompt_embeds: Optional[torch.Tensor] = None
# Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
_prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
init=False,
repr=False,
compare=False)
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
...@@ -325,22 +331,59 @@ class InputBatch: ...@@ -325,22 +331,59 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids. # Copy the prompt token ids and output token ids.
# OPTIMIZATION: Use np.copyto with pre-converted NumPy arrays
# instead of slice assignment. This avoids the ~550 µs overhead
# of converting Python list to NumPy array each time.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds( num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds) request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None: if request.prompt_token_ids is not None:
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[ self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = getattr(request, "_prompt_token_ids_np",
None)
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
try:
rebuild_prompt_cache = (prompt_token_ids_np.dtype !=
np.int32
or prompt_token_ids_np.size !=
num_prompt_tokens)
except Exception:
rebuild_prompt_cache = True
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids,
dtype=np.int32)
try:
request._prompt_token_ids_np = prompt_token_ids_np
except Exception:
pass
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting='no',
)
self.is_token_ids[req_index, :num_prompt_tokens] = True self.is_token_ids[req_index, :num_prompt_tokens] = True
else: else:
self.is_token_ids[req_index, :num_prompt_tokens] = False self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None: if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds self.req_prompt_embeds[req_index] = request.prompt_embeds
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids,
dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting='no',
)
num_spec_tokens = 0 num_spec_tokens = 0
if request.spec_token_ids != None: if request.spec_token_ids != None:
num_spec_tokens = len(request.spec_token_ids) num_spec_tokens = len(request.spec_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment