Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e782e0a1
Unverified
Commit
e782e0a1
authored
Apr 26, 2025
by
Aaron Pham
Committed by
GitHub
Apr 26, 2025
Browse files
[Chore] added stubs for `vllm_flash_attn` during development mode (#17228)
Signed-off-by:
Aaron Pham
<
contact@aarnphm.xyz
>
parent
dc2ceca5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
269 additions
and
2 deletions
+269
-2
pyproject.toml
pyproject.toml
+2
-1
setup.py
setup.py
+0
-1
vllm/vllm_flash_attn/__init__.py
vllm/vllm_flash_attn/__init__.py
+22
-0
vllm/vllm_flash_attn/flash_attn_interface.pyi
vllm/vllm_flash_attn/flash_attn_interface.pyi
+245
-0
No files found.
pyproject.toml
View file @
e782e0a1
...
@@ -58,7 +58,8 @@ ignore_patterns = [
...
@@ -58,7 +58,8 @@ ignore_patterns = [
line-length
=
80
line-length
=
80
exclude
=
[
exclude
=
[
# External file, leaving license intact
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py"
"examples/other/fp8/quantizer/quantize.py"
,
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
]
[tool.ruff.lint.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
...
...
setup.py
View file @
e782e0a1
...
@@ -378,7 +378,6 @@ class repackage_wheel(build_ext):
...
@@ -378,7 +378,6 @@ class repackage_wheel(build_ext):
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so"
,
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so"
,
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so"
,
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so"
,
"vllm/vllm_flash_attn/flash_attn_interface.py"
,
"vllm/vllm_flash_attn/flash_attn_interface.py"
,
"vllm/vllm_flash_attn/__init__.py"
,
"vllm/cumem_allocator.abi3.so"
,
"vllm/cumem_allocator.abi3.so"
,
# "vllm/_version.py", # not available in nightly wheels yet
# "vllm/_version.py", # not available in nightly wheels yet
]
]
...
...
vllm/vllm_flash_attn/__init__.py
View file @
e782e0a1
# SPDX-License-Identifier: Apache-2.0
import
importlib.metadata
# Resolve the package version from the installed distribution metadata,
# falling back to a development placeholder when none is registered.
try
:
__version__
=
importlib
.
metadata
.
version
(
"vllm-flash-attn"
)
except
importlib
.
metadata
.
PackageNotFoundError
:
# No "vllm-flash-attn" distribution is installed: this happens when
# vllm-flash-attn is built as part of an editable install of vllm, so use a
# placeholder development version instead of failing the import.
__version__
=
"0.0.0.dev0"
from
.flash_attn_interface
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
get_scheduler_metadata
,
is_fa_version_supported
,
sparse_attn_func
,
sparse_attn_varlen_func
)
# Public API of this package: exactly the names imported from
# .flash_attn_interface above.
__all__
=
[
'flash_attn_varlen_func'
,
'flash_attn_with_kvcache'
,
'get_scheduler_metadata'
,
'sparse_attn_func'
,
'sparse_attn_varlen_func'
,
'is_fa_version_supported'
,
'fa_version_unsupported_reason'
]
vllm/vllm_flash_attn/flash_attn_interface.pyi
0 → 100644
View file @
e782e0a1
# ruff: ignore
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any, Literal, overload
import torch
def get_scheduler_metadata(
    # Required problem-size arguments.
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads_q: int,
    num_heads_kv: int,
    headdim: int,
    cache_seqlens: torch.Tensor,
    # Optional arguments; ``...`` marks "has a default" without stating it.
    qkv_dtype: torch.dtype = ...,
    headdim_v: int | None = ...,
    cu_seqlens_q: torch.Tensor | None = ...,
    cu_seqlens_k_new: torch.Tensor | None = ...,
    cache_leftpad: torch.Tensor | None = ...,
    page_size: int = ...,
    max_seqlen_k_new: int = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,
    has_softcap: bool = ...,
    num_splits: int = ...,
    pack_gqa: Any | None = ...,
    sm_margin: int = ...,
):
    """Type stub for the scheduler-metadata helper.

    The first seven parameters (``batch_size`` through ``cache_seqlens``) are
    required; every later parameter has a default in the implementation, which
    this stub does not spell out.

    No return annotation is declared, so type checkers treat the result as
    ``Any``.
    NOTE(review): the result is presumably what the attention functions
    accept as ``scheduler_metadata`` -- confirm against the implementation
    and annotate the return type accordingly.
    """
    ...
# Typed overloads for attention over variable-length (packed) sequences.
#
# ``q``, ``k``, ``v``, ``alibi_slopes`` and the results are annotated as
# ``torch.Tensor``. A fixed-length ``tuple[int, ...]`` annotation describes a
# tensor's *shape*, not the object a caller passes, and makes type checkers
# reject real tensor arguments.
# NOTE(review): q/k/v and the attention output are presumably rank-3 tensors
# and ``alibi_slopes`` rank 1 or 2 -- confirm against the flash-attn docs.
#
# The two overloads differ only in ``return_softmax_lse``:
#   * ``Literal[False]`` -> a single tensor;
#   * ``Literal[True]``  -> a pair of tensors (presumably the attention output
#     and the softmax log-sum-exp -- confirm).
# NOTE(review): both overloads give ``return_softmax_lse`` a default, so a call
# that omits the flag resolves to the first (single-tensor) overload.
@overload
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    max_seqlen_q: int,
    cu_seqlens_q: torch.Tensor | None,
    max_seqlen_k: int,
    cu_seqlens_k: torch.Tensor | None = ...,
    seqused_k: Any | None = ...,
    q_v: Any | None = ...,
    dropout_p: float = ...,
    causal: bool = ...,
    window_size: list[int] | None = ...,
    softmax_scale: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    block_table: Any | None = ...,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> torch.Tensor: ...
@overload
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    max_seqlen_q: int,
    cu_seqlens_q: torch.Tensor | None,
    max_seqlen_k: int,
    cu_seqlens_k: torch.Tensor | None = ...,
    seqused_k: Any | None = ...,
    q_v: Any | None = ...,
    dropout_p: float = ...,
    causal: bool = ...,
    window_size: list[int] | None = ...,
    softmax_scale: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    block_table: Any | None = ...,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[torch.Tensor, torch.Tensor]: ...
# Typed overloads for attention against a key/value cache.
#
# ``q``, ``k_cache``, ``v_cache``, ``k``, ``v``, ``rotary_cos``, ``rotary_sin``,
# ``alibi_slopes`` and the results are annotated as ``torch.Tensor``. A
# fixed-length ``tuple[int, ...]`` annotation describes a tensor's *shape*, not
# the object a caller passes, and makes type checkers reject real tensors.
# NOTE(review): q/k/v and the caches are presumably rank-4 tensors and the
# rotary tables rank 2 -- confirm against the flash-attn docs.
#
# The two overloads differ only in ``return_softmax_lse``:
#   * ``Literal[False]`` -> a single tensor;
#   * ``Literal[True]``  -> a pair of tensors (presumably the attention output
#     and the softmax log-sum-exp -- confirm).
# NOTE(review): both overloads give ``return_softmax_lse`` a default, so a call
# that omits the flag resolves to the first (single-tensor) overload.
# Everything after the bare ``*`` is keyword-only.
@overload
def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor | None = ...,
    v: torch.Tensor | None = ...,
    rotary_cos: torch.Tensor | None = ...,
    rotary_sin: torch.Tensor | None = ...,
    cache_seqlens: int | torch.Tensor | None = None,
    cache_batch_idx: torch.Tensor | None = None,
    cache_leftpad: torch.Tensor | None = ...,
    block_table: torch.Tensor | None = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,  # -1 means infinite context window
    softcap: float = ...,
    rotary_interleaved: bool = ...,
    alibi_slopes: torch.Tensor | None = ...,
    num_splits: int = ...,
    return_softmax_lse: Literal[False] = ...,
    *,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> torch.Tensor: ...
@overload
def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor | None = ...,
    v: torch.Tensor | None = ...,
    rotary_cos: torch.Tensor | None = ...,
    rotary_sin: torch.Tensor | None = ...,
    cache_seqlens: int | torch.Tensor | None = None,
    cache_batch_idx: torch.Tensor | None = None,
    cache_leftpad: torch.Tensor | None = ...,
    block_table: torch.Tensor | None = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,  # -1 means infinite context window
    softcap: float = ...,
    rotary_interleaved: bool = ...,
    alibi_slopes: torch.Tensor | None = ...,
    num_splits: int = ...,
    return_softmax_lse: Literal[True] = ...,
    *,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[torch.Tensor, torch.Tensor]: ...
# Typed overloads for block/column-sparse attention on batched inputs.
#
# ``q``, ``k``, ``v``, the four sparsity-index arguments (``block_count``,
# ``block_offset``, ``column_count``, ``column_index``), ``alibi_slopes`` and
# the results are annotated as ``torch.Tensor``. A fixed-length tuple
# annotation describes a tensor's *shape*, not the object a caller passes, and
# makes type checkers reject real tensors.
# NOTE(review): q/k/v are presumably rank-4 tensors, the ``*_count`` indices
# rank 3 and ``block_offset`` / ``column_index`` rank 4 -- confirm against the
# flash-attn docs.
#
# The two overloads differ only in the keyword-only ``return_softmax_lse``:
#   * ``Literal[False]`` -> a single tensor;
#   * ``Literal[True]``  -> a pair of tensors (presumably the attention output
#     and the softmax log-sum-exp -- confirm).
# NOTE(review): both overloads give ``return_softmax_lse`` a default, so a call
# that omits the flag resolves to the first (single-tensor) overload.
@overload
def sparse_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_count: torch.Tensor,
    block_offset: torch.Tensor,
    column_count: torch.Tensor,
    column_index: torch.Tensor,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
) -> torch.Tensor: ...
@overload
def sparse_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_count: torch.Tensor,
    block_offset: torch.Tensor,
    column_count: torch.Tensor,
    column_index: torch.Tensor,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
) -> tuple[torch.Tensor, torch.Tensor]: ...
# Typed overloads for block/column-sparse attention on variable-length
# (packed) sequences.
#
# ``q``, ``k``, ``v``, the four sparsity-index arguments (``block_count``,
# ``block_offset``, ``column_count``, ``column_index``), ``alibi_slopes`` and
# the results are annotated as ``torch.Tensor``. A fixed-length tuple
# annotation describes a tensor's *shape*, not the object a caller passes, and
# makes type checkers reject real tensors.
# NOTE(review): q/k/v are presumably rank-3 tensors, the ``*_count`` indices
# rank 3 and ``block_offset`` / ``column_index`` rank 4 -- confirm against the
# flash-attn docs.
#
# Unlike the dense varlen function, ``cu_seqlens_q``, ``cu_seqlens_k``,
# ``max_seqlen_q`` and ``max_seqlen_k`` are all required here and follow the
# sparsity indices.
#
# The two overloads differ only in the keyword-only ``return_softmax_lse``:
#   * ``Literal[False]`` -> a single tensor;
#   * ``Literal[True]``  -> a pair of tensors (presumably the attention output
#     and the softmax log-sum-exp -- confirm).
# NOTE(review): both overloads give ``return_softmax_lse`` a default, so a call
# that omits the flag resolves to the first (single-tensor) overload.
@overload
def sparse_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_count: torch.Tensor,
    block_offset: torch.Tensor,
    column_count: torch.Tensor,
    column_index: torch.Tensor,
    cu_seqlens_q: torch.Tensor | None,
    cu_seqlens_k: torch.Tensor | None,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
) -> torch.Tensor: ...
@overload
def sparse_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_count: torch.Tensor,
    block_offset: torch.Tensor,
    column_count: torch.Tensor,
    column_index: torch.Tensor,
    cu_seqlens_q: torch.Tensor | None,
    cu_seqlens_k: torch.Tensor | None,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: torch.Tensor | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
) -> tuple[torch.Tensor, torch.Tensor]: ...
def is_fa_version_supported(
    fa_version: int,
    device: torch.device | None = None,
) -> bool:
    """Type stub: report whether FlashAttention version ``fa_version`` is usable.

    Args:
        fa_version: Version number of the FlashAttention backend to query.
        device: Optional device to check; defaults to ``None``
            (presumably meaning "the current device" -- confirm against the
            implementation).

    Returns:
        A ``bool`` according to the annotation; the stub body itself carries
        no logic.
    """
    ...
def fa_version_unsupported_reason(
    fa_version: int,
    device: torch.device | None = None,
) -> str | None:
    """Type stub: explain why FlashAttention version ``fa_version`` is unusable.

    Args:
        fa_version: Version number of the FlashAttention backend to query.
        device: Optional device to check; defaults to ``None``
            (presumably meaning "the current device" -- confirm against the
            implementation).

    Returns:
        ``str | None`` according to the annotation -- presumably a
        human-readable reason, or ``None`` when the version is supported
        (confirm against the implementation). The stub body itself carries
        no logic.
    """
    ...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment