Commit 903a588f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev

parents 64e307c7 ba0cd35c
...@@ -189,6 +189,7 @@ if TYPE_CHECKING: ...@@ -189,6 +189,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSE_SILU_AND_MUL: bool = False VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_PP_DEBUG: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in (os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")), ("true", "1")),
"VLLM_PP_DEBUG":
lambda:
(os.environ.get("VLLM_PP_DEBUG", "False").lower() in
("true", "1")),
# vllm will use fused rmsnorm + contiguous + rope(for dpsk-v3) + concat_and_cache_mla # vllm will use fused rmsnorm + contiguous + rope(for dpsk-v3) + concat_and_cache_mla
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT": "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT":
......
...@@ -168,7 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -168,7 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if not envs.VLLM_USE_CAT_MLA: if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
...@@ -181,7 +181,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -181,7 +181,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA: if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
o, _ = flash_mla_with_kvcache( o, _ = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
...@@ -275,6 +275,12 @@ class EngineCore: ...@@ -275,6 +275,12 @@ class EngineCore:
pass pass
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0: if scheduler_output.total_num_scheduled_tokens > 0:
if envs.VLLM_PP_DEBUG:
import sys,os
num_run_reqs = len(scheduler_output.scheduled_new_reqs) + scheduler_output.scheduled_cached_reqs.num_reqs
sys.stderr.write(f"[pid- {os.getpid()}]running requests in micro batch is:{num_run_reqs}, "
f"total_num_scheduled_tokens is {scheduler_output.total_num_scheduled_tokens}\n")
sys.stderr.flush()
future = self.model_executor.execute_model(scheduler_output) future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait( self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore (future, scheduler_output)) # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment