Commit ef037256 authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_PA_PRINT_PARAM flag to print pa size

parent bd93e661
......@@ -160,7 +160,7 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
#"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
......
......@@ -123,6 +123,11 @@ class PagedAttention:
if use_v1:
# Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v1_opt(
output,
......@@ -179,6 +184,13 @@ class PagedAttention:
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.USE_VLLM_OPT_OP:
ops.paged_attention_v2_opt(
output,
......
......@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
USE_VLLM_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
......@@ -144,6 +145,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"USE_VLLM_OPT_OP":
lambda: (os.environ.get("USE_VLLM_OPT_OP", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment