Commit d3a95d54 authored by laibao's avatar laibao
Browse files

• perf(v1): 增加可选的快速 token-id 拷贝路径

  - 新增环境变量 `VLLM_V1_FAST_TOKEN_ID_COPY`(默认关闭)
  - 在 `CachedRequestState` 中缓存 int32 的 prompt token ids(numpy 数组)
  - 开启后在 `InputBatch` 中使用 `np.copyto` 拷贝 prompt/output token ids
parent 35006c0f
......@@ -293,6 +293,7 @@ if TYPE_CHECKING:
VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root():
......@@ -1847,6 +1848,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, cast
import numpy as np
......@@ -47,6 +47,12 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
_prompt_token_ids_np: np.ndarray | None = field(
default=None,
init=False,
repr=False,
compare=False,
)
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
......@@ -332,15 +338,41 @@ class InputBatch:
)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = request._prompt_token_ids_np
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
rebuild_prompt_cache = (
prompt_token_ids_np.dtype != np.int32
or prompt_token_ids_np.size != num_prompt_tokens
)
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids, dtype=np.int32)
request._prompt_token_ids_np = prompt_token_ids_np
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment