Commit b12cca5e authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_tc_opt' into 'v0.11.0-dev'

perf: 加速 v1 InputBatch.add_request 的 token_ids 拷贝

See merge request dcutoolkit/deeplearing/vllm!346
parents c2ef7fdd 8da572a9
...@@ -242,7 +242,7 @@ if TYPE_CHECKING: ...@@ -242,7 +242,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1675,6 +1675,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1675,6 +1675,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")), ("true", "1")),
# If set to "true" or "1", vLLM V1 uses the fast NumPy-based token id copy
# path in InputBatch.add_request (disabled by default).
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch # Datastructures defining a GPU input batch
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional, cast from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm import envs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -45,6 +46,11 @@ class CachedRequestState: ...@@ -45,6 +46,11 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None prompt_embeds: Optional[torch.Tensor] = None
# Lazily populated when `VLLM_V1_FAST_TOKEN_ID_COPY` is enabled.
_prompt_token_ids_np: Optional[np.ndarray] = field(default=None,
init=False,
repr=False,
compare=False)
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
...@@ -325,22 +331,59 @@ class InputBatch: ...@@ -325,22 +331,59 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids. # Copy the prompt token ids and output token ids.
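# OPTIMIZATION (only when VLLM_V1_FAST_TOKEN_ID_COPY is enabled): copy with
# np.copyto from an int32 NumPy array instead of slice-assigning a Python
# list. The prompt array is cached on the request, which avoids the ~550 µs
# overhead of converting the prompt list to a NumPy array on every call;
# output token ids are still converted on each call.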
num_prompt_tokens = length_from_prompt_token_ids_or_embeds( num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds) request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None: if request.prompt_token_ids is not None:
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
self.token_ids_cpu[ self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids req_index, :num_prompt_tokens] = request.prompt_token_ids
else:
prompt_token_ids_np = getattr(request, "_prompt_token_ids_np",
None)
rebuild_prompt_cache = True
if prompt_token_ids_np is not None:
try:
rebuild_prompt_cache = (prompt_token_ids_np.dtype !=
np.int32
or prompt_token_ids_np.size !=
num_prompt_tokens)
except Exception:
rebuild_prompt_cache = True
if rebuild_prompt_cache:
prompt_token_ids_np = np.asarray(request.prompt_token_ids,
dtype=np.int32)
try:
request._prompt_token_ids_np = prompt_token_ids_np
except Exception:
pass
np.copyto(
self.token_ids_cpu[req_index, :num_prompt_tokens],
prompt_token_ids_np,
casting='no',
)
self.is_token_ids[req_index, :num_prompt_tokens] = True self.is_token_ids[req_index, :num_prompt_tokens] = True
else: else:
self.is_token_ids[req_index, :num_prompt_tokens] = False self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None: if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds self.req_prompt_embeds[req_index] = request.prompt_embeds
if not envs.VLLM_V1_FAST_TOKEN_ID_COPY:
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids start_idx:end_idx] = request.output_token_ids
else:
output_token_ids_np = np.asarray(request.output_token_ids,
dtype=np.int32)
end_idx = start_idx + output_token_ids_np.size
np.copyto(
self.token_ids_cpu[req_index, start_idx:end_idx],
output_token_ids_np,
casting='no',
)
num_spec_tokens = 0 num_spec_tokens = 0
if request.spec_token_ids != None: if request.spec_token_ids != None:
num_spec_tokens = len(request.spec_token_ids) num_spec_tokens = len(request.spec_token_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment