"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "714c12daddb5e86c07625b14d839e0d39e968e2c"
Commit b12cca5e authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_tc_opt' into 'v0.11.0-dev'

perf: 加速 v1 InputBatch.add_request 的 token_ids 拷贝 (speed up the token_ids copy in v1 InputBatch.add_request)

See merge request dcutoolkit/deeplearing/vllm!346
parents c2ef7fdd 8da572a9
......@@ -242,7 +242,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1675,6 +1675,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
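# If set, v1 InputBatch.add_request copies token ids into token_ids_cpu with
# np.copyto from int32 NumPy arrays instead of assigning the Python lists.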
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......@@ -1787,4 +1791,4 @@ def compute_hash() -> str:
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
\ No newline at end of file
return hash_str
......@@ -2,13 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, cast
import numpy as np
import torch
from typing_extensions import deprecated
from vllm import envs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams
......@@ -45,6 +46,11 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
# Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
_prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
init=False,
repr=False,
compare=False)
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
......@@ -325,22 +331,59 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
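# OPTIMIZATION (only when VLLM_V1_FAST_TOKEN_ID_COPY is enabled): use
# np.copyto with a pre-converted int32 NumPy array of the prompt, cached on
# the request, instead of slice assignment. This avoids the ~550 µs overhead
# of converting the prompt's Python list to a NumPy array each time.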
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = getattr(request, "_prompt_token_ids_np",
None)
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
try:
rebuild_prompt_cache = (prompt_token_ids_np.dtype !=
np.int32
or prompt_token_ids_np.size !=
num_prompt_tokens)
except Exception:
rebuild_prompt_cache = True
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids,
dtype=np.int32)
try:
request._prompt_token_ids_np = prompt_token_ids_np
except Exception:
pass
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting='no',
)
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids,
dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting='no',
)
num_spec_tokens = 0
if request.spec_token_ids != None:
num_spec_tokens = len(request.spec_token_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment