Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8a1e7a3d
Commit
8a1e7a3d
authored
Aug 13, 2025
by
zhuwenwen
Browse files
add VLLM_USE_PA_PRINT_PARAM to print v1 fa size
parent
74a444b5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
3 deletions
+67
-3
setup.py
setup.py
+2
-2
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+57
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+3
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-0
No files found.
setup.py
View file @
8a1e7a3d
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt1.
beta
.'
+
sha
[:
7
]
version
=
'das.opt1.
rc1
.'
+
sha
[:
7
]
else
:
else
:
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt1.
beta
'
version
=
'das.opt1.
rc1
'
# dtk version
# dtk version
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
8a1e7a3d
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
...
@@ -122,6 +123,7 @@ class FlashAttentionMetadata:
...
@@ -122,6 +123,7 @@ class FlashAttentionMetadata:
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
# seq_lens_tensor: torch.Tensor
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
...
@@ -226,6 +228,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -226,6 +228,7 @@ class FlashAttentionMetadataBuilder(
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
# seq_lens_tensor = common_attn_metadata.seq_lens_tensor
block_table
=
self
.
block_table
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
...
@@ -388,6 +391,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -388,6 +391,7 @@ class FlashAttentionMetadataBuilder(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
# seq_lens_tensor=seq_lens_tensor,
block_table
=
block_table_tensor
,
block_table
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
...
@@ -590,6 +594,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -590,6 +594,12 @@ class FlashAttentionImpl(AttentionImpl):
num_splits
=
attn_metadata
.
max_num_splits
,
num_splits
=
attn_metadata
.
max_num_splits
,
)
)
else
:
else
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA SIZE:"
)
print
(
f
"q.shape =
{
query
[:
num_actual_tokens
].
unsqueeze
(
1
).
shape
}
, key_cache.shape =
{
key_cache
.
shape
}
, value_cache.shape =
{
value_cache
.
shape
}
"
)
print
(
f
"cu_seqlens_q.shape =
{
cu_seqlens_q
.
shape
}
, max_seqlen_q =
{
max_seqlen_q
}
, seqused_k.shape =
{
seqused_k
.
shape
}
, max_seqlen_k =
{
max_seqlen_k
}
"
)
print
(
f
"softmax_scale =
{
self
.
scale
:.
3
f
}
, alibi_slopes =
{
self
.
alibi_slopes
}
, window_size =
{
self
.
sliding_window
}
, block_tables.shape =
{
block_table
.
shape
}
, softcap =
{
self
.
logits_soft_cap
}
, scheduler_metadata =
{
scheduler_metadata
}
"
)
vllm_flash_attn_varlen_func
(
vllm_flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
k
=
key_cache
,
...
@@ -613,6 +623,53 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -613,6 +623,53 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache
=
False
,
is_prefix_cache
=
False
,
)
)
# if num_actual_tokens > 1:
# vllm_flash_attn_varlen_func(
# q=query[:num_actual_tokens],
# k=key_cache,
# v=value_cache,
# out=output[:num_actual_tokens],
# cu_seqlens_q=cu_seqlens_q,
# max_seqlen_q=max_seqlen_q,
# seqused_k=seqused_k,
# max_seqlen_k=max_seqlen_k,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# # fa_version=self.vllm_flash_attn_version,
# # q_descale=layer._q_scale.expand(descale_shape),
# # k_descale=layer._k_scale.expand(descale_shape),
# # v_descale=layer._v_scale.expand(descale_shape),
# # num_splits=attn_metadata.max_num_splits,
# is_prefix_cache=False,
# )
# else:
# from flash_attn import vllm_flash_attn_with_kvcache
# if envs.VLLM_USE_PA_PRINT_PARAM:
# print("PA SIZE:")
# print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
# print(f"cache_seqlens.shape = {attn_metadata.seq_lens_tensor.shape}, block_tables.shape = {block_table.shape}")
# print(f"softmax_scale = {self.scale:.3f}, window_size = {self.sliding_window}, softcap = {self.logits_soft_cap}, alibi_slopes = {self.alibi_slopes}")
# output[:num_actual_tokens] = vllm_flash_attn_with_kvcache(
# q=query[:num_actual_tokens].unsqueeze(1),
# k_cache=key_cache,
# v_cache=value_cache,
# cache_seqlens=attn_metadata.seq_lens_tensor,
# softmax_scale=self.scale,
# causal=True,
# alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# block_table=block_table,
# softcap=self.logits_soft_cap,
# # k_scale=layer._k_scale.expand(descale_shape),
# # v_scale=layer._v_scale.expand(descale_shape),
# ).squeeze(1)
return
output
return
output
assert
not
use_local_attn
,
(
assert
not
use_local_attn
,
(
...
...
vllm/v1/attention/backends/utils.py
View file @
8a1e7a3d
...
@@ -4,7 +4,7 @@ import abc
...
@@ -4,7 +4,7 @@ import abc
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -35,6 +35,8 @@ class CommonAttentionMetadata:
...
@@ -35,6 +35,8 @@ class CommonAttentionMetadata:
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
"""(batch_size,), the length of each request including both computed tokens
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
and newly scheduled tokens"""
# seq_lens_tensor: torch.Tensor
# """seq_lens stored as a tensor."""
num_reqs
:
int
num_reqs
:
int
"""Number of requests"""
"""Number of requests"""
num_actual_tokens
:
int
num_actual_tokens
:
int
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8a1e7a3d
...
@@ -247,6 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -247,6 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
seq_lens
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
seq_lens
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
# self.seq_lens_tensor = torch.zeros_like(self.seq_lens)
self
.
slot_mapping
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
slot_mapping
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -700,10 +701,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -700,10 +701,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
[:
num_reqs
+
1
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
# seq_lens_tensor=seq_lens_tensor,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
...
@@ -2018,11 +2021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2018,11 +2021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
# seq_lens_tensor = self.seq_lens_tensor[:num_reqs]
num_speculative_tokens
=
0
if
self
.
speculative_config
is
None
else
self
.
speculative_config
.
num_lookahead_slots
num_speculative_tokens
=
0
if
self
.
speculative_config
is
None
else
self
.
speculative_config
.
num_lookahead_slots
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
# seq_lens_tensor=seq_lens_tensor,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
num_tokens
,
max_query_len
=
num_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment