Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef037256
Commit
ef037256
authored
Aug 21, 2024
by
zhuwenwen
Browse files
Add VLLM_USE_PA_PRINT_PARAM flag to print pa size
parent
bd93e661
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
1 deletion
+19
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+12
-0
vllm/envs.py
vllm/envs.py
+6
-0
No files found.
CMakeLists.txt
View file @
ef037256
...
@@ -160,7 +160,7 @@ set(VLLM_EXT_SRC
...
@@ -160,7 +160,7 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
#
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
...
...
vllm/attention/ops/paged_attn.py
View file @
ef037256
...
@@ -123,6 +123,11 @@ class PagedAttention:
...
@@ -123,6 +123,11 @@ class PagedAttention:
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
USE_VLLM_OPT_OP
:
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
paged_attention_v1_opt
(
ops
.
paged_attention_v1_opt
(
output
,
output
,
...
@@ -179,6 +184,13 @@ class PagedAttention:
...
@@ -179,6 +184,13 @@ class PagedAttention:
device
=
output
.
device
,
device
=
output
.
device
,
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V2 SIZE:"
)
print
(
f
"exp_sums.shape =
{
exp_sums
.
shape
}
, max_logits.shape =
{
max_logits
.
shape
}
, tmp_output.shape =
{
tmp_output
.
shape
}
"
)
print
(
f
"query.shape =
{
query
.
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"num_kv_heads =
{
num_kv_heads
}
, scale =
{
scale
:.
3
f
}
, block_tables.shape =
{
block_tables
.
shape
}
, seq_lens.shape =
{
seq_lens
.
shape
}
, block_size =
{
block_size
}
, max_seq_len =
{
max_seq_len
}
"
)
if
envs
.
USE_VLLM_OPT_OP
:
if
envs
.
USE_VLLM_OPT_OP
:
ops
.
paged_attention_v2_opt
(
ops
.
paged_attention_v2_opt
(
output
,
output
,
...
...
vllm/envs.py
View file @
ef037256
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
...
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
VLLM_USE_FLASH_ATTN_AUTO
:
bool
=
False
USE_VLLM_OPT_OP
:
bool
=
False
USE_VLLM_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
LOCAL_RANK
:
int
=
0
LOCAL_RANK
:
int
=
0
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
CUDA_VISIBLE_DEVICES
:
Optional
[
str
]
=
None
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
VLLM_ENGINE_ITERATION_TIMEOUT_S
:
int
=
60
...
@@ -145,6 +146,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -145,6 +146,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"USE_VLLM_OPT_OP"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"USE_VLLM_OPT_OP"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
"LOCAL_RANK"
:
"LOCAL_RANK"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment