Commit 0c1fa562 authored by zhuwenwen
Browse files

Fix the Device2Host copy synchronization that occurs during the MLA prefill phase

parent 3fb0bfca
......@@ -1316,8 +1316,11 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
has_context = prefill_metadata.context_lens_tensor is not None \
and prefill_metadata.context_lens_tensor.max() > 0
if envs.VLLM_HAS_CONTEXT_DEFAULT:
has_context = prefill_metadata.context_lens_tensor is not None \
and prefill_metadata.context_lens_tensor.max() > 0
else:
has_context = False
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
......
......@@ -9,23 +9,12 @@ from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_OPTEST_URLS_PORT: Optional[int] = None
VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
......@@ -122,6 +111,19 @@ if TYPE_CHECKING:
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
# add envs
VLLM_OPTEST_URLS_PORT: Optional[int] = None
VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
def get_default_cache_root():
......@@ -228,16 +230,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
'VLLM_PORT':
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# used in optest environment to manually set the https port
'VLLM_OPTEST_URLS_PORT':
lambda: int(os.getenv('VLLM_OPTEST_URLS_PORT', '8000'))
if 'VLLM_OPTEST_URLS_PORT' in os.environ else None,
# Path to the optest models.
# If set, will load models from local path instead of Hugging Face Hub.
'VLLM_OPTEST_MODELS_PATH':
lambda: os.getenv('VLLM_OPTEST_MODELS_PATH', "") or os.getenv("OPTEST_MODELS_PATH", ""),
# path used for ipc when the frontend api server is running in
# multi-processing mode to communicate with the backend engine process.
......@@ -272,43 +264,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use triton prefix flash attention
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER":
lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))),
# flag to control vllm to use optimized kernels
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE":
lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "0"))),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
......@@ -794,6 +754,53 @@ environment_variables: dict[str, Callable[[], Any]] = {
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
# used in optest environment to manually set the https port
'VLLM_OPTEST_URLS_PORT':
lambda: int(os.getenv('VLLM_OPTEST_URLS_PORT', '8000'))
if 'VLLM_OPTEST_URLS_PORT' in os.environ else None,
# Path to the optest models.
# If set, will load models from local path instead of Hugging Face Hub.
'VLLM_OPTEST_MODELS_PATH':
lambda: os.getenv('VLLM_OPTEST_MODELS_PATH', "") or os.getenv("OPTEST_MODELS_PATH", ""),
# flag to control if vllm should use triton prefix flash attention
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER":
lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))),
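# flag to control the custom allreduce path on PCIe (presumably; confirm
# against the code that reads it). Parsed as an integer: "0" (default)
# disables it, any non-zero integer enables it.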
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE":
lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "0"))),
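# Integer batch-size threshold related to enforcing eager mode; defaults to -1
# when unset. NOTE(review): exact semantics are not visible here; confirm
# against the code that reads it. The TYPE_CHECKING declaration says
# Optional[int] = None, but this getter never returns None.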
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
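# Controls how MLA prefill decides whether a request has existing context.
# If set to 1, has_context is computed from
# prefill_metadata.context_lens_tensor.max() > 0, which incurs the
# Device2Host copy/synchronization. If unset or 0 (the default), has_context
# is forced to False and that check (and its synchronization) is skipped.
# NOTE(review): with the default, prefill never takes the has_context path;
# confirm this is acceptable when requests can carry cached context.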
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.environ.get("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
}
# end-env-vars-definition
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment