Commit 2f9e0bad authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_PA_PRINT_PARAM flag to print pa size

parent 1c18cce0
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
...@@ -128,6 +129,11 @@ class PagedAttention: ...@@ -128,6 +129,11 @@ class PagedAttention:
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
...@@ -163,6 +169,13 @@ class PagedAttention: ...@@ -163,6 +169,13 @@ class PagedAttention:
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
......
...@@ -12,6 +12,7 @@ if TYPE_CHECKING: ...@@ -12,6 +12,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
...@@ -187,6 +188,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -187,6 +188,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE": "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment