Commit ef037256 authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_PA_PRINT_PARAM flag to print pa size

parent bd93e661
...@@ -160,7 +160,7 @@ set(VLLM_EXT_SRC ...@@ -160,7 +160,7 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_kernels_opt.cu" "csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
#"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
......
...@@ -123,6 +123,11 @@ class PagedAttention: ...@@ -123,6 +123,11 @@ class PagedAttention:
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP: if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v1_opt( ops.paged_attention_v1_opt(
output, output,
...@@ -179,6 +184,13 @@ class PagedAttention: ...@@ -179,6 +184,13 @@ class PagedAttention:
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP: if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v2_opt( ops.paged_attention_v2_opt(
output, output,
......
...@@ -11,6 +11,7 @@ if TYPE_CHECKING: ...@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False VLLM_USE_FLASH_ATTN_AUTO: bool = False
USE_VLLM_OPT_OP: bool = False USE_VLLM_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
...@@ -145,6 +146,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -145,6 +146,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": "LOCAL_RANK":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment