Commit 570c2c5b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.15.1-dev-fast-token-id-copy' into 'v0.15.1-dev'

perf(v1): 增加可选的快速 token-id 拷贝路径

See merge request dcutoolkit/deeplearing/vllm!440
parents 35006c0f d3a95d54
......@@ -293,6 +293,7 @@ if TYPE_CHECKING:
VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root():
......@@ -1847,6 +1848,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
#If set to 1/True, enable the V1 fast token-id copy path in InputBatch.
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, cast
import numpy as np
......@@ -47,6 +47,12 @@ class CachedRequestState:
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
_prompt_token_ids_np: np.ndarray | None = field(
default=None,
init=False,
repr=False,
compare=False,
)
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
......@@ -332,15 +338,41 @@ class InputBatch:
)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = request._prompt_token_ids_np
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
rebuild_prompt_cache = (
prompt_token_ids_np.dtype != np.int32
or prompt_token_ids_np.size != num_prompt_tokens
)
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids, dtype=np.int32)
request._prompt_token_ids_np = prompt_token_ids_np
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids, dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting="no",
)
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment