Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2179 additions
and
341 deletions
+2179
-341
tests/worker/test_swap.py
tests/worker/test_swap.py
+13
-13
vllm/__init__.py
vllm/__init__.py
+11
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+99
-14
vllm/attention/__init__.py
vllm/attention/__init__.py
+2
-3
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+42
-41
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+410
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+174
-78
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+63
-33
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+106
-34
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+53
-24
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+96
-27
vllm/attention/layer.py
vllm/attention/layer.py
+52
-10
vllm/attention/ops/blocksparse_attention/__init__.py
vllm/attention/ops/blocksparse_attention/__init__.py
+0
-0
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
...ops/blocksparse_attention/blocksparse_attention_kernel.py
+423
-0
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+238
-0
vllm/attention/ops/blocksparse_attention/utils.py
vllm/attention/ops/blocksparse_attention/utils.py
+216
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+33
-10
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+30
-17
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+10
-0
vllm/attention/selector.py
vllm/attention/selector.py
+108
-35
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
tests/worker/test_swap.py
View file @
b9e12416
...
@@ -54,36 +54,36 @@ def test_swap() -> None:
...
@@ -54,36 +54,36 @@ def test_swap() -> None:
a
.
cuda
(),
b
.
cuda
(),
rtol
=
0.0
,
atol
=
0.0
)
a
.
cuda
(),
b
.
cuda
(),
rtol
=
0.0
,
atol
=
0.0
)
# Test swap out.
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
blocks_to_swap_out
=
[(
3
,
72
)
,
(
56
,
35
)
,
(
84
,
34
)]
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
[],
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
{}
,
blocks_to_swap_in
=
[]
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{}
,
blocks_to_copy
=
[]
,
)
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_out
.
items
()
:
for
src
,
dst
in
blocks_to_swap_out
:
assert
allclose
(
gpu_key_cache
[
src
],
cpu_key_cache
[
dst
])
assert
allclose
(
gpu_key_cache
[
src
],
cpu_key_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
# Test swap in.
execute_model_req
.
blocks_to_swap_out
=
{}
execute_model_req
.
blocks_to_swap_out
=
[]
execute_model_req
.
blocks_to_swap_in
=
{
execute_model_req
.
blocks_to_swap_in
=
[
19
:
45
,
(
19
,
45
)
,
67
:
23
,
(
67
,
23
)
,
12
:
78
,
(
12
,
78
)
,
40
:
99
,
(
40
,
99
)
,
1
:
71
(
1
,
71
),
}
]
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
.
items
()
:
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
:
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
vllm/__init__.py
View file @
b9e12416
...
@@ -5,22 +5,31 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -5,22 +5,31 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptStrictInputs
,
TextPrompt
,
TokensPrompt
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.version
import
__dcu_version__
from
vllm.version
import
__dcu_version__
__version__
=
"0.4.
2
"
__version__
=
"0.4.
3
"
__all__
=
[
__all__
=
[
"LLM"
,
"LLM"
,
"ModelRegistry"
,
"ModelRegistry"
,
"PromptStrictInputs"
,
"TextPrompt"
,
"TokensPrompt"
,
"SamplingParams"
,
"SamplingParams"
,
"RequestOutput"
,
"RequestOutput"
,
"CompletionOutput"
,
"CompletionOutput"
,
"EmbeddingOutput"
,
"EmbeddingRequestOutput"
,
"LLMEngine"
,
"LLMEngine"
,
"EngineArgs"
,
"EngineArgs"
,
"AsyncLLMEngine"
,
"AsyncLLMEngine"
,
"AsyncEngineArgs"
,
"AsyncEngineArgs"
,
"initialize_ray_cluster"
,
"initialize_ray_cluster"
,
"PoolingParams"
,
]
]
vllm/_custom_ops.py
View file @
b9e12416
from
typing
import
Dict
,
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -45,11 +45,17 @@ def paged_attention_v1(
...
@@ -45,11 +45,17 @@ def paged_attention_v1(
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v1
(
out
,
query
,
key_cache
,
value_cache
,
vllm_ops
.
paged_attention_v1
(
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
block_size
,
max_seq_len
,
alibi_slopes
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
)
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_v2
(
def
paged_attention_v2
(
...
@@ -69,12 +75,18 @@ def paged_attention_v2(
...
@@ -69,12 +75,18 @@ def paged_attention_v2(
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
None
:
)
->
None
:
vllm_ops
.
paged_attention_v2
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
vllm_ops
.
paged_attention_v2
(
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
block_tables
,
seq_lens
,
block_size
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
alibi_slopes
,
kv_cache_dtype
,
kv_scale
,
tp_rank
,
kv_scale
)
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
)
# pos encoding ops
# pos encoding ops
...
@@ -153,6 +165,32 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -153,6 +165,32 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
size_n
,
size_k
)
# marlin_24
def
gptq_marlin_24_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_24_gemm
(
a
,
b_q_weight
,
b_meta
,
b_scales
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
)
# cutlass
def
cutlass_scaled_mm_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
vllm_ops
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
a_scales
,
b_scales
)
return
out
# aqlm
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
@@ -189,8 +227,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -189,8 +227,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# def scaled_fp8_quant(
# def scaled_fp8_quant(
# input: torch.Tensor,
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# scale: Optional[torch.Tensor] = None,
# batch_dim_padding: Optional[int] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensor for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# batch_dim_padding: If specified, pad the first dimension
# of the output to at least this value.
# Returns:
# Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# if batch_dim_padding:
# shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
# output = torch.empty(shape,
# device=input.device,
# dtype=torch.float8_e4m3fn)
# else:
# output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# if scale is None:
# if scale is None:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
# vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
...
@@ -199,6 +263,24 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -199,6 +263,24 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# return output, scale
# return output, scale
# int8
# def static_scaled_int8_quant(input: torch.Tensor,
# scale: float) -> torch.Tensor:
# """
# Quantize the input tensor to int8 and return the quantized tensor.
# Args:
# input: The input tensor to be quantized to int8.
# scale: Scaling factor for the int8 quantization.
# Returns:
# torch.Tensor: Output tensor in int8.
# """
# q = torch.empty_like(input, dtype=torch.int8)
# vllm_ops.static_scaled_int8_quant(q, input, scale)
# return q
# moe
# moe
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
@@ -240,12 +322,15 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
...
@@ -240,12 +322,15 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
def
swap_blocks
(
src
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_mapping
:
Dict
[
int
,
int
]
)
->
None
:
block_mapping
:
torch
.
Tensor
)
->
None
:
vllm_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
vllm_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
def
convert_fp8
(
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
)
->
None
:
def
convert_fp8
(
output
:
torch
.
Tensor
,
vllm_cache_ops
.
convert_fp8
(
output
,
input
)
input
:
torch
.
Tensor
,
scale
:
float
=
1.0
,
kv_dtype
:
str
=
"fp8"
)
->
None
:
vllm_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
#TODO: cuda_utils, custom_ar
#TODO: cuda_utils, custom_ar
vllm/attention/__init__.py
View file @
b9e12416
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
__all__
=
[
__all__
=
[
"Attention"
,
"AttentionBackend"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"AttentionMetadata"
,
"Attention"
,
"Attention"
,
"get_attn_backend"
,
"get_attn_backend"
,
"AttentionMetadataPerStage"
,
]
]
vllm/attention/backends/abstract.py
View file @
b9e12416
...
@@ -9,6 +9,11 @@ import torch
...
@@ -9,6 +9,11 @@ import torch
class
AttentionBackend
(
ABC
):
class
AttentionBackend
(
ABC
):
"""Abstract class for attention backends."""
"""Abstract class for attention backends."""
@
staticmethod
@
abstractmethod
def
get_name
()
->
str
:
raise
NotImplementedError
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
get_impl_cls
()
->
Type
[
"AttentionImpl"
]:
def
get_impl_cls
()
->
Type
[
"AttentionImpl"
]:
...
@@ -16,7 +21,7 @@ class AttentionBackend(ABC):
...
@@ -16,7 +21,7 @@ class AttentionBackend(ABC):
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata
PerStage
"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
staticmethod
...
@@ -34,7 +39,7 @@ class AttentionBackend(ABC):
...
@@ -34,7 +39,7 @@ class AttentionBackend(ABC):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -42,14 +47,40 @@ class AttentionBackend(ABC):
...
@@ -42,14 +47,40 @@ class AttentionBackend(ABC):
@
abstractmethod
@
abstractmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
dataclass
@
dataclass
class
AttentionMetadataPerStage
:
class
AttentionMetadata
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills
:
int
# Number of prefill tokens.
num_prefill_tokens
:
int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens
:
int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
@
property
@
abstractmethod
def
prefill_metadata
(
self
)
->
Optional
[
"AttentionMetadata"
]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@
property
@
abstractmethod
def
decode_metadata
(
self
)
->
Optional
[
"AttentionMetadata"
]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def
asdict_zerocopy
(
self
,
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
skip_fields
:
Optional
[
Set
[
str
]]
=
None
...
@@ -65,42 +96,10 @@ class AttentionMetadataPerStage:
...
@@ -65,42 +96,10 @@ class AttentionMetadataPerStage:
}
}
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadataPerStage
)
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
@
dataclass
class
AttentionMetadata
(
Generic
[
T
]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills
:
int
# Number of prefill tokens.
num_prefill_tokens
:
int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens
:
int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata
:
Optional
[
T
]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata
:
Optional
[
T
]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# The kv cache's data type.
kv_cache_dtype
:
str
def
__post_init__
(
self
):
if
self
.
num_prefill_tokens
>
0
:
assert
self
.
num_prefills
>
0
assert
self
.
prefill_metadata
is
not
None
if
self
.
num_decode_tokens
>
0
:
assert
self
.
decode_metadata
is
not
None
class
AttentionImpl
(
ABC
):
class
AttentionImpl
(
ABC
,
Generic
[
T
]
):
@
abstractmethod
@
abstractmethod
def
__init__
(
def
__init__
(
...
@@ -111,6 +110,8 @@ class AttentionImpl(ABC):
...
@@ -111,6 +110,8 @@ class AttentionImpl(ABC):
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -121,7 +122,7 @@ class AttentionImpl(ABC):
...
@@ -121,7 +122,7 @@ class AttentionImpl(ABC):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
T
,
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
0 → 100644
View file @
b9e12416
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
@
dataclass
class
BlocksparseParams
:
max_seqlen
:
int
# Num q heads per tensor-parallel rank/partition
num_heads
:
int
# per TP partition
# Num kv heads per tensor-parallel rank/partition
num_kv_heads
:
int
# block size used for blocksparse attention.
# This is the block_size used in `local_blocks`, `vert_stride`.
block_size
:
int
# Number of blocks for local attention, i.e., number of
# local attended tokens / `sparse_block_size`
local_blocks
:
int
# Attend to one block per every `vert_stride` blocks.
# Controlling the sparsity
vert_stride
:
int
"""
If to use the same vertical stride offset for all heads,
i.e., attend to the same block of tokens on all heads.
By default, it is False, i.e., attention on the non-local
blocks depends on the `head_idx`, that is on
blocks satisfying
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
`block_idx = position_id // sparse_block_size`.
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
for more detail.
"""
homo_head
:
bool
=
False
# If within a group, the kv offsets that each q attends is the same or no.
homo_head_group
:
bool
=
False
# Decided by homo_head and homo_head group
head_sliding_step
:
int
=
field
(
init
=
False
)
# range of q heads to for a TP rank
active_head_range
:
Tuple
=
field
(
init
=
False
)
def
__post_init__
(
self
):
assert
self
.
block_size
>
0
assert
self
.
local_blocks
>=
0
assert
self
.
vert_stride
>=
1
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
total_heads
=
tp_size
*
self
.
num_heads
total_kv_heads
=
tp_size
*
self
.
num_kv_heads
if
self
.
homo_head
:
self
.
head_sliding_step
=
0
elif
self
.
homo_head_group
:
head_sliding_step
=
get_head_sliding_step
(
total_kv_heads
,
self
.
vert_stride
)
# negative indicates sliding along kv heads, i.e., homo q group
self
.
head_sliding_step
=
-
head_sliding_step
else
:
self
.
head_sliding_step
=
get_head_sliding_step
(
total_heads
,
self
.
vert_stride
)
self
.
active_head_range
=
(
tp_rank
*
self
.
num_heads
,
(
tp_rank
+
1
)
*
self
.
num_heads
,
)
class
BlocksparseFlashAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_impl_cls
()
->
Type
[
"BlocksparseFlashAttentionImpl"
]:
return
BlocksparseFlashAttentionImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"BlocksparseFlashAttentionMetadata"
:
return
BlocksparseFlashAttentionMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
class
BlocksparseFlashAttentionMetadata
(
AttentionMetadata
):
"""A copy of Metadata for FlashAttentionBackend,
to avoid having to install flash_attn.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
_cached_prefill_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"BlocksparseFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
block_tables
is
not
None
assert
self
.
seq_start_loc
is
not
None
self
.
_cached_prefill_metadata
=
BlocksparseFlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
],
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"BlocksparseFlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
_cached_decode_metadata
=
BlocksparseFlashAttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
)
return
self
.
_cached_decode_metadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
    """Attention implementation that uses a block-sparse kernel for prompts.

    Prompt (prefill) tokens are processed by ``LocalStridedBlockSparseAttn``;
    generation (decode) tokens are processed by
    ``PagedAttention.forward_decode`` with the block-sparse parameters
    forwarded to it.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Validate the configuration and build the block-sparse kernel.

        Args:
            num_heads: Number of query heads handled by this instance.
            head_size: Size of each attention head; must be one of
                ``PagedAttention.get_supported_head_sizes()``.
            scale: Softmax scale; stored as a ``float``.
            num_kv_heads: Number of key/value heads. ``None`` falls back to
                ``num_heads``.
            alibi_slopes: Must be ``None``.
            sliding_window: Must be ``None``.
            kv_cache_dtype: Passed through to the paged-cache write and
                decode calls.
            blocksparse_params: Keyword arguments for ``BlocksparseParams``.
                Required. ``num_heads`` and ``num_kv_heads`` are filled in
                from the arguments above when absent. The mapping passed in
                is not modified.

        Raises:
            AssertionError: If ``blocksparse_params`` is ``None``, if
                ``alibi_slopes`` or ``sliding_window`` is not ``None``, or if
                ``num_heads`` is not a multiple of the number of KV heads.
            ValueError: If ``head_size`` is not supported by PagedAttention.
        """
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi not support for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")

        # Fill in the head counts on a shallow copy: writing the defaults
        # into the caller's dict would leak this instance's values into
        # whatever else shares that dict.
        params = dict(blocksparse_params)
        params.setdefault("num_heads", num_heads)
        params.setdefault("num_kv_heads", num_kv_heads or num_heads)
        self.blocksparse_params = BlocksparseParams(**params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # Always None here because of the assertion above; kept as an
        # attribute because forward() passes it to forward_decode().
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Cache the sparsity pattern parameters used by the decode path.
        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # The kernel is configured with the head count across all
        # tensor-parallel ranks.
        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: Forwarded to the paged-cache write and to the decode
                kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                kv_scale,
            )

        # NOTE(review): `output` is bound only inside the two branches below.
        # If both metadata objects are present the decode result replaces the
        # prefill result, and if neither is present the final `return` fails.
        # Presumably callers supply exactly one of them -- confirm.
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            # normal attention
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            assert (kv_cache is None
                    or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0), \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # key_cache / value_cache exist only when kv_cache was provided
            # above, so this branch relies on a non-None kv_cache.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
vllm/attention/backends/flash_attn.py
View file @
b9e12416
"""Attention layer with Flash and PagedAttention.
"""Attention layer with FlashAttention."""
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
vllm_
flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm._C
import
cache_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
@staticmethod
def get_supported_head_sizes() -> List[int]:
    """Return the head sizes accepted by this backend.

    The list is every multiple of 32 from 32 through 256 inclusive.
    """
    return list(range(32, 257, 32))
@staticmethod
def get_name() -> str:
    """Return the string identifier of this attention backend."""
    backend_name = "flash-attn"
    return backend_name
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
return
FlashAttentionImpl
return
FlashAttentionImpl
...
@@ -34,28 +35,36 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -34,28 +35,36 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
if
block_size
%
16
!=
0
:
num_kv_heads
,
head_size
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadataPerStage
,
class
FlashAttentionMetadata
(
AttentionMetadata
):
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -63,9 +72,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -63,9 +72,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
updated from `CUDAGraphRunner.forward` API.
"""
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
seq_lens
:
Optional
[
List
[
int
]]
...
@@ -80,14 +86,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -80,14 +86,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch.
None for decoding.
max_query_len
:
Optional
[
int
]
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
# Maximum sequence length among prefill batch. 0 if there are decoding
max_seq_len
:
Optional
[
int
]
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
sub
query_start_loc
:
Optional
[
torch
.
Tensor
]
query_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
# [4, 6], it is [0, 4, 10].
...
@@ -96,11 +106,83 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -96,11 +106,83 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
_cached_prefill_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the metadata restricted to the prefill requests.

    Returns ``None`` when ``num_prefills`` is zero. Otherwise the result is
    built on first access by slicing the leading entries of each field, and
    memoized in ``_cached_prefill_metadata`` for later accesses.
    """
    if self.num_prefills == 0:
        return None

    cached = self._cached_prefill_metadata
    if cached is not None:
        return cached

    # Every field sliced below has to be populated.
    assert self.seq_lens is not None
    assert self.seq_lens_tensor is not None
    assert self.query_start_loc is not None
    assert self.context_lens_tensor is not None
    assert self.block_tables is not None
    assert self.seq_start_loc is not None

    n_prefills = self.num_prefills
    n_prefill_tokens = self.num_prefill_tokens
    prefill_view = FlashAttentionMetadata(
        num_prefills=n_prefills,
        num_prefill_tokens=n_prefill_tokens,
        num_decode_tokens=0,
        slot_mapping=self.slot_mapping[:n_prefill_tokens],
        seq_lens=self.seq_lens[:n_prefills],
        seq_lens_tensor=self.seq_lens_tensor[:n_prefills],
        max_query_len=self.max_query_len,
        max_prefill_seq_len=self.max_prefill_seq_len,
        max_decode_seq_len=0,
        # The *_start_loc fields carry one more entry than there are
        # sequences, hence the "+ 1".
        query_start_loc=self.query_start_loc[:n_prefills + 1],
        seq_start_loc=self.seq_start_loc[:n_prefills + 1],
        context_lens_tensor=self.context_lens_tensor[:n_prefills],
        block_tables=self.block_tables[:n_prefills],
        use_cuda_graph=False,
    )
    self._cached_prefill_metadata = prefill_view
    return prefill_view
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
    """Return the metadata restricted to the decode requests.

    Returns ``None`` when ``num_decode_tokens`` is zero. Otherwise the result
    is built on first access from the entries that follow the prefill ones,
    and memoized in ``_cached_decode_metadata`` for later accesses.
    """
    if self.num_decode_tokens == 0:
        return None

    cached = self._cached_decode_metadata
    if cached is not None:
        return cached

    assert self.block_tables is not None
    assert self.seq_lens_tensor is not None

    first_decode_seq = self.num_prefills
    first_decode_token = self.num_prefill_tokens
    decode_view = FlashAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=self.num_decode_tokens,
        slot_mapping=self.slot_mapping[first_decode_token:],
        seq_lens=None,
        seq_lens_tensor=self.seq_lens_tensor[first_decode_seq:],
        max_query_len=None,
        max_prefill_seq_len=0,
        max_decode_seq_len=self.max_decode_seq_len,
        query_start_loc=None,
        seq_start_loc=None,
        context_lens_tensor=None,
        block_tables=self.block_tables[first_decode_seq:],
        use_cuda_graph=self.use_cuda_graph,
    )
    self._cached_decode_metadata = decode_view
    return decode_view
class
FlashAttentionImpl
(
AttentionImpl
):
class
FlashAttentionImpl
(
AttentionImpl
):
"""
"""
...
@@ -133,28 +215,39 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -133,28 +215,39 @@ class FlashAttentionImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"FlashAttention does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
sliding_window
is
not
None
:
if
head_size
not
in
suppored_head_sizes
:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
"Sliding window is not supported in FlashAttention."
)
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -162,20 +255,23 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -162,20 +255,23 @@ class FlashAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
[
FlashAttentionMetadata
]
,
attn_metadata
:
FlashAttentionMetadata
,
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention
and PagedAttention
.
"""Forward pass with FlashAttention.
Args:
Args:
query: shape = [num_tokens, num_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size
*
num_kv_heads
*
head_size]
kv_cache = [2, num_blocks, block_size
,
num_kv_heads
,
head_size]
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
kv_scale
==
1.0
,
"kv_scale is not supported in FlashAttention."
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
@@ -183,17 +279,20 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -183,17 +279,20 @@ class FlashAttentionImpl(AttentionImpl):
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
is
not
None
:
if
kv_cache
is
not
None
:
key_cache
,
value_cache
=
PagedAttention
.
split_
kv_cache
(
key_cache
=
kv_cache
[
0
]
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
value_cache
=
kv_cache
[
1
]
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
# not cached. This happens during the initial memory profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
cache_ops
.
reshape_and_cache_flash
(
value_cache
,
key
,
attn_metadata
.
slot_mapping
,
value
,
attn_metadata
.
kv_cache_dtype
,
key_cache
,
kv_scale
)
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
@@ -213,7 +312,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -213,7 +312,8 @@ class FlashAttentionImpl(AttentionImpl):
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
...
@@ -223,8 +323,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -223,8 +323,8 @@ class FlashAttentionImpl(AttentionImpl):
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_
prefill_
seq_len
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_
prefill_
seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
...
@@ -234,38 +334,34 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -234,38 +334,34 @@ class FlashAttentionImpl(AttentionImpl):
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
assert
prefill_meta
.
seq_lens
is
not
None
# deal with different data types between KV and FP8 KV cache,
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
# to be addressed separately.
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
q
=
query
,
query
,
k
=
key_cache
,
key
,
v
=
value_cache
,
value
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
key_cache
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
value_cache
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
prefill_meta
.
block_tables
,
max_seqlen_k
=
max_seq_len
,
prefill_meta
.
subquery_start_loc
,
softmax_scale
=
self
.
scale
,
prefill_meta
.
seq_lens_tensor
,
causal
=
True
,
prefill_meta
.
context_lens_tensor
,
alibi_slopes
=
self
.
alibi_slopes
,
prefill_meta
.
max_query_len
,
block_table
=
prefill_meta
.
block_tables
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decod
e
(
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcach
e
(
decode_query
,
decode_query
.
unsqueeze
(
1
)
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
attn_metadata
.
kv_cache_dtype
,
causal
=
True
,
self
.
num_kv_heads
,
alibi_slopes
=
self
.
alibi_slopes
,
self
.
scale
,
).
squeeze
(
1
)
self
.
alibi_slopes
,
kv_scale
,
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flashinfer.py
View file @
b9e12416
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
import
flashinfer
import
flashinfer
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
except
ImportError
:
flashinfer
=
None
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
import
torch
import
torch
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
@staticmethod
def get_name() -> str:
    """Return the string identifier of this attention backend."""
    backend_name = "flashinfer"
    return backend_name
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashInferImpl"
]:
def
get_impl_cls
()
->
Type
[
"FlashInferImpl"
]:
return
FlashInferImpl
return
FlashInferImpl
...
@@ -41,14 +38,14 @@ class FlashInferBackend(AttentionBackend):
...
@@ -41,14 +38,14 @@ class FlashInferBackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -58,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
...
@@ -58,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
@
dataclass
@
dataclass
class
FlashInferMetadata
(
AttentionMetadataPerStage
):
class
FlashInferMetadata
(
AttentionMetadata
):
# Maximum sequence length among prefill batch. 0 if there are decoding
is_prompt
:
bool
# requests only.
max_prefill_seq_len
:
int
use_cuda_graph
:
bool
=
False
use_cuda_graph
:
bool
=
False
...
@@ -69,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
...
@@ -69,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# Metadata for the prefill stage since we still
# Metadata for the prefill stage since we still
# use flash attention for prefill.
# use flash attention for prefill.
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
max_seq_len
:
Optional
[
int
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# Metadata for the decode stage
# Metadata for the decode stage
...
@@ -115,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
...
@@ -115,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# When using flashinfer, we are also creating the FlashInferMetadata,
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
# post_init if it's the prefill phase.
if
not
self
.
is_prompt
:
if
self
.
num_prefills
==
0
:
assert
self
.
num_decode_tokens
>
0
self
.
decode_wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
self
.
decode_wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
)
self
.
workspace_buffer
,
"NHD"
)
self
.
decode_wrapper
.
begin_forward
(
self
.
decode_wrapper
.
begin_forward
(
...
@@ -140,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
...
@@ -140,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
skip_fields
.
add
(
'decode_wrapper'
)
skip_fields
.
add
(
'decode_wrapper'
)
return
super
().
asdict_zerocopy
(
skip_fields
)
return
super
().
asdict_zerocopy
(
skip_fields
)
@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
    """Return this object for a prefill batch, otherwise ``None``.

    A batch counts as prefill when it has no decode tokens; in that case it
    must contain at least one prefill request.
    """
    # Currently chunked prefill is not supported
    if self.num_decode_tokens != 0:
        return None
    assert self.num_prefills > 0
    return self
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
    """Return this object for a decode batch, otherwise ``None``.

    A batch with any prefill request is not a decode batch; such a batch must
    contain no decode tokens.
    """
    # Currently chunked prefill is not supported
    if self.num_prefills <= 0:
        return self
    assert self.num_decode_tokens == 0
    return None
class
FlashInferImpl
(
AttentionImpl
):
class
FlashInferImpl
(
AttentionImpl
):
...
@@ -148,23 +164,36 @@ class FlashInferImpl(AttentionImpl):
...
@@ -148,23 +164,36 @@ class FlashInferImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
not
None
:
if
sliding_window
is
not
None
:
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
alibi_slopes
=
alibi_slopes
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
scale
=
scale
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
attn_metadata
:
AttentionMetadata
[
FlashInferMetadata
],
kv_scale
:
float
):
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
FlashInferMetadata
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
assert
kv_scale
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -185,10 +214,11 @@ class FlashInferImpl(AttentionImpl):
...
@@ -185,10 +214,11 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
0
],
kv_cache
[:,
0
],
kv_cache
[:,
1
],
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
)
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
block_tables
is
not
None
assert
prefill_meta
.
block_tables
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
output
=
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
...
@@ -197,8 +227,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -197,8 +227,8 @@ class FlashInferImpl(AttentionImpl):
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_
prefill_
seq_len
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_
prefill_
seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
b9e12416
"""Attention layer ROCm GPUs."""
"""Attention layer ROCm GPUs."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -17,6 +16,10 @@ logger = init_logger(__name__)
...
@@ -17,6 +16,10 @@ logger = init_logger(__name__)
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
@staticmethod
def get_name() -> str:
    """Return the string identifier of this attention backend."""
    backend_name = "rocm-flash-attn"
    return backend_name
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"ROCmFlashAttentionImpl"
]:
def
get_impl_cls
()
->
Type
[
"ROCmFlashAttentionImpl"
]:
return
ROCmFlashAttentionImpl
return
ROCmFlashAttentionImpl
...
@@ -39,21 +42,20 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -39,21 +42,20 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadataPerStage
,
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -61,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -61,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
updated from `CUDAGraphRunner.forward` API.
"""
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
seq_lens
:
Optional
[
List
[
int
]]
...
@@ -78,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -78,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch.
None for decoding.
max_query_len
:
Optional
[
int
]
max_query_len
:
Optional
[
int
]
# Maximum sequence length in the batch.
# Maximum sequence length among prefill batch. 0 if there are decoding
max_seq_len
:
Optional
[
int
]
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
sub
query_start_loc
:
Optional
[
torch
.
Tensor
]
query_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
# [4, 6], it is [0, 4, 10].
...
@@ -98,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
...
@@ -98,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# (batch_size,) A tensor of context lengths (tokens that are computed
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
block_tables
is
not
None
assert
self
.
seq_start_loc
is
not
None
self
.
_cached_prefill_metadata
=
ROCmFlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
],
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
_cached_decode_metadata
=
ROCmFlashAttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
)
return
self
.
_cached_decode_metadata
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
...
@@ -131,28 +197,33 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -131,28 +197,33 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"ROCFlashAttention does not support blocksparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
suppor
t
ed_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
if
head_size
not
in
suppor
t
ed_head_sizes
:
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
suppor
t
ed_head_sizes
}
."
)
self
.
use_naive_attn
=
False
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
...
@@ -163,8 +234,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -163,8 +234,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
attn_func
=
triton_attention
self
.
attn_func
=
triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
else
:
# if not using triton, navi3x not use flash-attn either
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
if
torch
.
cuda
.
get_device_capability
()[
0
]
==
11
:
# either
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
else
:
else
:
try
:
try
:
...
@@ -192,7 +264,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -192,7 +264,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
[
ROCmFlashAttentionMetadata
]
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -225,7 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -225,7 +297,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
kv_scale
,
)
)
...
@@ -260,8 +332,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -260,8 +332,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None
,
None
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_seq_len
,
prefill_meta
.
max_
prefill_
seq_len
,
prefill_meta
.
max_seq_len
,
prefill_meta
.
max_
prefill_
seq_len
,
True
,
True
,
self
.
scale
,
self
.
scale
,
)
)
...
@@ -284,8 +356,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -284,8 +356,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_seq_len
,
max_seqlen_q
=
prefill_meta
.
max_
prefill_
seq_len
,
max_seqlen_k
=
prefill_meta
.
max_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_
prefill_
seq_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
...
@@ -302,7 +374,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -302,7 +374,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
sub
query_start_loc
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
...
@@ -318,8 +390,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -318,8 +390,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
decode_meta
.
max_
decode_
seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
...
...
vllm/attention/backends/torch_sdpa.py
View file @
b9e12416
""" Attention layer with torch scaled_dot_product_attention
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
and PagedAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
class
TorchSDPABackend
(
AttentionBackend
):
class
TorchSDPABackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"torch-sdpa"
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"TorchSDPABackendImpl"
]:
def
get_impl_cls
()
->
Type
[
"TorchSDPABackendImpl"
]:
return
TorchSDPABackendImpl
return
TorchSDPABackendImpl
...
@@ -37,21 +40,20 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -37,21 +40,20 @@ class TorchSDPABackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
,
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
AttentionMetadataPerStage
):
"""Metadata for TorchSDPABackend.
"""Metadata for TorchSDPABackend.
"""
"""
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
...
@@ -68,37 +70,64 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
...
@@ -68,37 +70,64 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_decode_tokens
==
0
:
assert
self
.
num_prefills
>
0
return
self
return
None
@
property
def
decode_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
return
None
class
TorchSDPABackendImpl
(
AttentionImpl
):
return
self
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
def
__init__
(
def
__init__
(
self
,
self
,
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
assert
len
(
alibi_slopes
)
==
num_heads
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
self
.
sliding_window
=
sliding_window
or
self
.
sliding_window
is
not
None
)
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
if
head_size
not
in
suppored_head_sizes
:
or
self
.
sliding_window
is
not
None
)
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -107,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -107,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -120,6 +149,7 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -120,6 +149,7 @@ class TorchSDPABackendImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
kv_scale
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
@@ -132,8 +162,7 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -132,8 +162,7 @@ class TorchSDPABackendImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
)
kv_scale
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens
is
not
None
...
@@ -190,8 +219,8 @@ class TorchSDPABackendImpl(AttentionImpl):
...
@@ -190,8 +219,8 @@ class TorchSDPABackendImpl(AttentionImpl):
value_cache
,
value_cache
,
attn_metadata
.
block_tables
,
attn_metadata
.
block_tables
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_seq_len
,
attn_metadata
.
max_
decode_
seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
...
...
vllm/attention/backends/xformers.py
View file @
b9e12416
"""Attention layer with xFormers and PagedAttention."""
"""Attention layer with xFormers and PagedAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias
)
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
)
AttentionMetadataPerStage
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -20,6 +19,10 @@ logger = init_logger(__name__)
...
@@ -20,6 +19,10 @@ logger = init_logger(__name__)
class
XFormersBackend
(
AttentionBackend
):
class
XFormersBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"xformers"
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"XFormersImpl"
]:
def
get_impl_cls
()
->
Type
[
"XFormersImpl"
]:
return
XFormersImpl
return
XFormersImpl
...
@@ -49,13 +52,13 @@ class XFormersBackend(AttentionBackend):
...
@@ -49,13 +52,13 @@ class XFormersBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
PerStage
,
PagedAttentionMetadata
):
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for XFormersbackend.
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
NOTE: Any python object stored here is not updated when it is
...
@@ -63,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -63,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
updated from `CUDAGraphRunner.forward` API.
"""
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
seq_lens
:
Optional
[
List
[
int
]]
...
@@ -79,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -79,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch.
None for decoding.
max_query_len
:
Optional
[
int
]
max_query_len
:
Optional
[
int
]
# FIXME: It is for flash attn.
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
# Maximum sequence length among prefill batch. 0 if there are decoding
max_seq_len
:
Optional
[
int
]
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
sub
query_start_loc
:
Optional
[
torch
.
Tensor
]
query_start_loc
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# the batch, used to index into sequence. E.g., if the sequence length is
...
@@ -101,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -101,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
_cached_prefill_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# Set during the execution of the first attention op.
...
@@ -110,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
...
@@ -110,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# will not appear in the __repr__ and __init__
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
@
property
class
XFormersImpl
(
AttentionImpl
):
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
block_tables
is
not
None
self
.
_cached_prefill_metadata
=
XFormersMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
None
,
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
_cached_decode_metadata
=
XFormersMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
)
return
self
.
_cached_decode_metadata
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
"""
If the input tensors contain prompt tokens, the layout is as follows:
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--------------- num_prefill_tokens ----------------->|
...
@@ -142,18 +208,23 @@ class XFormersImpl(AttentionImpl):
...
@@ -142,18 +208,23 @@ class XFormersImpl(AttentionImpl):
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
"XFormer does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -170,8 +241,8 @@ class XFormersImpl(AttentionImpl):
...
@@ -170,8 +241,8 @@ class XFormersImpl(AttentionImpl):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
XFormersMetadata
]
,
attn_metadata
:
"
XFormersMetadata
"
,
kv_scale
:
float
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -184,7 +255,6 @@ class XFormersImpl(AttentionImpl):
...
@@ -184,7 +255,6 @@ class XFormersImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -199,8 +269,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -199,8 +269,7 @@ class XFormersImpl(AttentionImpl):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
)
kv_scale
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
...
@@ -240,7 +309,7 @@ class XFormersImpl(AttentionImpl):
...
@@ -240,7 +309,7 @@ class XFormersImpl(AttentionImpl):
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
sub
query_start_loc
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
context_lens_tensor
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
...
@@ -257,8 +326,8 @@ class XFormersImpl(AttentionImpl):
...
@@ -257,8 +326,8 @@ class XFormersImpl(AttentionImpl):
value_cache
,
value_cache
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
seq_lens_tensor
,
decode_meta
.
max_seq_len
,
decode_meta
.
max_
decode_
seq_len
,
attn_metadata
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
...
...
vllm/attention/layer.py
View file @
b9e12416
"""Attention layer."""
"""Attention layer."""
from
typing
import
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention.backends.abstract
import
(
AttentionMetadata
,
from
vllm.attention.backends.abstract
import
AttentionMetadata
AttentionMetadataPerStage
)
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
...
@@ -28,13 +30,53 @@ class Attention(nn.Module):
...
@@ -28,13 +30,53 @@ class Attention(nn.Module):
scale
:
float
,
scale
:
float
,
num_kv_heads
:
Optional
[
int
]
=
None
,
num_kv_heads
:
Optional
[
int
]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
backend
=
get_attn_backend
(
torch
.
get_default_dtype
())
if
cache_config
is
not
None
:
impl_cls
=
self
.
backend
.
get_impl_cls
()
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
sliding_window
=
cache_config
.
sliding_window
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
sliding_window
=
None
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
_kv_scale
=
1.0
quant_method
=
quant_config
.
get_quant_method
(
self
)
if
quant_config
else
None
if
quant_method
is
not
None
:
if
self
.
kv_cache_dtype
==
"fp8_e5m2"
:
raise
ValueError
(
"fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints."
)
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self
.
quant_method
=
quant_method
self
.
quant_method
.
create_weights
(
self
)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
num_heads
,
head_size
,
num_kv_heads
,
sliding_window
,
dtype
,
kv_cache_dtype
,
block_size
,
blocksparse_params
is
not
None
)
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
)
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -42,15 +84,15 @@ class Attention(nn.Module):
...
@@ -42,15 +84,15 @@ class Attention(nn.Module):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
[
AttentionMetadataPerStage
],
attn_metadata
:
AttentionMetadata
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
kv_scale
)
self
.
_
kv_scale
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
+=
f
", num_heads=
{
self
.
impl
.
num_heads
}
"
# type: ignore
s
+=
f
", num_heads=
{
self
.
impl
.
num_heads
}
"
# type: ignore
s
+=
f
", num_kv_heads=
{
self
.
impl
.
num_kv_heads
}
"
# type: ignore
s
+=
f
", num_kv_heads=
{
self
.
impl
.
num_kv_heads
}
"
# type: ignore
s
+=
f
", scale=
{
self
.
impl
.
scale
}
"
# type: ignore
s
+=
f
", scale=
{
self
.
impl
.
scale
}
"
# type: ignore
s
+=
f
", backend=
{
self
.
impl
.
__class__
.
__name__
}
"
return
s
return
s
vllm/attention/ops/blocksparse_attention/__init__.py
0 → 100644
View file @
b9e12416
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
0 → 100644
View file @
b9e12416
import
torch
import
triton
import
triton.language
as
tl
def blocksparse_flash_attn_varlen_fwd(
        q,
        k,
        v,  # (#tokens, n_heads, head_size)
        cu_seqlens_k,
        cu_seqlens_q,
        sm_scale,
        sparse_layout,
        *,
        block_size=64,
        q_block_size=None,
        max_seqlen=None):
    """Launch the block-sparse attention Triton kernel on packed sequences.

    Args:
        q, k, v: 3-D tensors laid out as (#tokens, n_heads, head_size).
            ``k`` and ``v`` must share a shape, the number of query heads
            must be a multiple of the number of key heads, and ``q``/``k``
            must share the head size.
        cu_seqlens_k: 1-D tensor of cumulative key lengths; its length
            minus one is taken as the batch size.
        cu_seqlens_q: cumulative query lengths, or ``None``. When ``None``
            it is derived: ``arange(batch_size + 1)`` if ``q`` has one token
            per sequence (decoding only), or ``cu_seqlens_k`` if ``q`` and
            ``k`` have the same number of tokens.
        sm_scale: softmax scale passed through to the kernel.
        sparse_layout: list/tuple ``(crow_indices, col_indices)``.
        block_size: key block size (``BLOCK_N`` of the kernel).
        q_block_size: query block size (``BLOCK_M``); falls back to
            ``block_size`` when falsy.
        max_seqlen: optional upper bound asserted against the key lengths.

    Returns:
        A new tensor with the same shape as ``q``.

    Raises:
        ValueError: if ``cu_seqlens_q`` is ``None`` and cannot be derived.
    """
    # split q to blocks

    assert isinstance(sparse_layout, (list, tuple))

    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1
    q_block_size = q_block_size or block_size

    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    # TODO(linxihui): allow k, v to have different head_size
    assert k.shape == v.shape
    assert cu_seqlens_k.dim() == 1

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size:  # decoding only
            cu_seqlens_q = torch.arange(
                0,
                batch_size + 1,
                dtype=cu_seqlens_k.dtype,
                device=cu_seqlens_k.device,
            )
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            # Implicit literal concatenation keeps source indentation out of
            # the message (a backslash continuation inside the string would
            # embed it).
            raise ValueError(
                "cu_seqlens_q must be specified "
                "if q is a mix of prefilling and decoding.")
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # switch to use cpu to avoid too many kernel launches when iterated over
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
        "length of q should either be 1 (decoding) or same as k (prefilling).")

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    # Number of query blocks per sequence (ceil division).
    n_blocks = (q_lens + q_block_size - 1) // q_block_size

    # One entry per query block: the sequence it belongs to ...
    q_batch_ids = torch.tensor(
        [i for i, n in enumerate(n_blocks) for _ in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )
    # ... and its starting token offset inside that sequence.
    q_start_sids = torch.tensor(
        [i * q_block_size for n in n_blocks for i in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all().item()
    # One program per (query block, head).
    grid = (len(q_start_sids), n_heads, 1)

    _fwd_kernel_batch_inference[grid](
        q,
        k,
        v,
        out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,
        # The leading 0 fills the batch-stride slot; HAS_BATCH_DIM is False.
        0,
        *q.stride(),
        0,
        *k.stride(),
        0,
        *v.stride(),
        0,
        *out.stride(),
        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),
        q_k_ratio,
        HAS_BATCH_DIM=False,
        D_HEAD=head_size,
        BLOCK_M=q_block_size,
        BLOCK_N=block_size,
        BLOCK_D=block_d,
        BLOCK_M_LOADING=(16 if decoding_only else
                         q_block_size),  # smaller for decoding
        EVEN_D=block_d == head_size,
        num_warps=1 if decoding_only else 4,
        num_stages=3)

    return out
@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h,
    layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h,
    offs_m,
    offs_n,
    offs_d,
    stride_kt,
    stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """Fold one key/value block into the running attention state.

    The key-block id is read from the layout column array at
    `k_block_col_idx` for head `off_h`. The scaled q.k scores of that
    block are computed, optionally masked, and merged into the running
    max `m_i`, running sum `l_i` and accumulator `acc` using base-2
    exponentials.

    `Q` is only used for its element type (the dtype `p` is cast to).

    Returns the updated `(acc, l_i, m_i)`.
    """
    # Look up which key block this column entry of the layout refers to.
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
    # Token offset of the first key in that block.
    start_n = k_block_id * BLOCK_N
    if LAST_K_BLOCK:
        # Last block: mask keys at or beyond k_seqlen (and, when the head
        # size is not a power of two, head-dim lanes >= D_HEAD).
        if EVEN_D:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=offs_n[None, :] + start_n < k_seqlen,
            )
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=(offs_n[None, :] + start_n < k_seqlen) &
                (offs_d[:, None] < D_HEAD),
            )
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    # Entries where (offs_m + past_len) < (start_n + offs_n) get -inf added,
    # i.e. keys positioned after the query row are excluded.
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(
            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
            0,
            float("-inf"),
        )

    # flash-attn2
    # New running max, block probabilities relative to it, and the factor
    # (alpha) that rescales previously accumulated values to the new max.
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    p = tl.math.exp2(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)
    alpha = tl.math.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]
    # update m_i
    m_i = m_ij
    l_i = l_i * alpha + l_ij

    # Cast probabilities to the element type of Q before the second dot.
    p = p.to(Q.dtype.element_ty)
    # update acc
    if LAST_K_BLOCK:
        # Same masking as for k, with the mask axes swapped for v's layout.
        if EVEN_D:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=offs_n[:, None] + start_n < k_seqlen,
            )
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=(offs_n[:, None] + start_n < k_seqlen) &
                (offs_d[None, :] < D_HEAD),
            )
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD)

    acc += tl.dot(p, v)

    return acc, l_i, m_i
@triton.heuristics({
    "M_LT_N":
    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
    Q,
    K,
    V,
    Out,
    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,
    stride_qb,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ob,
    stride_ot,
    stride_oh,
    stride_od,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h,
    layout_crow_stride_m,
    layout_col_stride_h,
    layout_col_stride_m,
    q_k_ratio,
    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """Block-sparse attention forward for one (query block, head) program.

    Program axis 0 selects the query block, axis 1 the query head, and
    (only when HAS_BATCH_DIM) axis 2 the batch entry. The key blocks to
    visit are read from the layout row pointers (`layout_crow_ptr`) and
    column indices (`layout_col_ptr`).

    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

    TODO(linxihui):
    Optimize grouped-attn
    """
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    # Key/value head serving this query head (grouped attention).
    off_h_for_kv = off_h // q_k_ratio

    if HAS_BATCH_DIM:
        off_z = tl.program_id(2)
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob
        start_m = off_zm
        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
    else:
        # Packed layout: look up the sequence index and the in-sequence
        # start offset for this query block.
        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
        q_start_sid = tl.load(q_start_sids + off_zm)
        start_m = q_start_sid // BLOCK_M  # q_sbid

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # Start offset and length of this sequence in q and in k.
    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
    # Number of keys that precede the first query token of this sequence.
    past_len = k_seqlen - q_seqlen

    # Advance base pointers to this sequence and head.
    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    # Position block id of the query block: the layout row to read.
    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=offs_m[:, None] < q_seqlen,
        )
    else:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
            other=0,
        )

    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
                       q_pbid * layout_crow_stride_m)

    # TODO(linxihui): load at once, with any Triton version
    # that supports `tl.split`, e.g., Triton 3.0
    # Half-open range [k_block_start, k_block_end) of column entries
    # belonging to this layout row.
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    # Running max (starts at -inf), running sum and output accumulator.
    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    # k is addressed as (head_dim, keys) and v as (keys, head_dim).
    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    sm_scale *= (
        1.44269504  # 1/ln(2): base 2 is used for exponential and logarithm
    )

    # All blocks except the last are processed with LAST_K_BLOCK=False.
    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h,
            layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h,
            offs_m,
            offs_n,
            offs_d,
            stride_kt,
            stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N,
        )

    # The final block is processed with LAST_K_BLOCK=True, which enables
    # the sequence-length masking in the inner routine.
    acc, l_i, m_i = _fwd_kernel_inner(
        acc,
        l_i,
        m_i,
        q,
        Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h,
        layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h,
        offs_m,
        offs_n,
        offs_d,
        stride_kt,
        stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N,
    )

    # flash-attn 2
    # Normalise the accumulator by the running sum. m_i is updated here
    # but not read again within this kernel.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]

    # write output
    if EVEN_D:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=offs_m[:, None] < q_seqlen,
        )
    else:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
        )
vllm/attention/ops/blocksparse_attention/interface.py
0 → 100644
View file @
b9e12416
import
math
import
torch
from
vllm.utils
import
is_cpu
,
is_hip
from
.utils
import
(
dense_to_crow_col
,
get_head_sliding_step
,
get_sparse_attn_mask
)
# True only when CUDA is available and the first element of
# torch.cuda.get_device_capability() is at least 8.
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
                         and torch.cuda.get_device_capability()[0] >= 8)

# The Triton kernel entry point is imported only in that case, so the name
# is undefined otherwise.
if IS_COMPUTE_8_OR_ABOVE:
    from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
class LocalStridedBlockSparseAttn(torch.nn.Module):
    """Block-sparse attention with a local + vertically-strided pattern.

    The sparse layout is built once in ``__init__`` via
    ``get_sparse_attn_mask``. ``forward`` dispatches either to the Triton
    kernel (``varlen_attn``) or to a padded
    ``scaled_dot_product_attention`` fallback (``spda``), depending on
    ``use_spda``.
    """

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        """Build the sparse layout (and the dense mask when ``use_spda``).

        ``use_spda`` defaults to True on HIP, on CPU, or when
        ``IS_COMPUTE_8_OR_ABOVE`` is False. ``device`` defaults to the
        current CUDA device when CUDA is available, else ``"cpu"``.
        ``q_block_size``, when given and different from ``block_size``,
        must be a larger multiple of it.

        Raises:
            ValueError: if ``q_block_size`` is smaller than ``block_size``.
        """
        super().__init__()
        if use_spda is None:
            use_spda = (is_hip() or is_cpu() or not IS_COMPUTE_8_OR_ABOVE)
        device = device or (torch.cuda.current_device()
                            if torch.cuda.is_available() else "cpu")
        device = torch.device(device)
        # NOTE: vllm CPU backend support BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                # Merge groups of consecutive query-block rows so that one
                # layout row corresponds to one (larger) query block.
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "Does not support smaller q_block_size. It will be slower."
                )

        self.sparse_layout = sparse_layout

    def get_attn_pattern(self, dtype, device):
        """Return ``(sparse_layout, sparse_pattern, dense_attn_mask)``.

        The dense mask is requested from ``get_sparse_attn_mask`` only when
        ``self.use_spda`` is set (as a "bias" mask); otherwise it is
        whatever that helper returns for ``return_dense=False``. With
        heterogeneous heads and an ``active_head_range`` tuple
        ``(h_start, h_end)``, the layout (and dense mask, if any) is sliced
        to that head range.
        """
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
        Support grouped attention, with `q[:, i*r:(i*r + r)]`
        is correspondent to `k[:, i]`, where `r` is the q/k ratio.
        cu_seqlens_k: shape=(batch_size + 1,),
        indicating segment of samples,
        e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of sample i
        cu_seqlens_q: shape=(batch_size + 1, ).
        Default None: same as cu_seqlens_k for prefilling or
        [0, 1, .., batch_size] for decoding.
        The only case you need to specify is when q is a mix of
        prefilling and decoding.
        sm_scale: softmax scale, default to 1/sqrt(head_size).

        return: tensor of shape as q.
        """
        # Implicit literal concatenation: a backslash continuation inside
        # the string would embed the source indentation in the message.
        assert (IS_COMPUTE_8_OR_ABOVE), (
            "Requires compute capability of 8 or above (Ampere or newer) "
            "to use Triton kernel.")

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """
        :param x: (total_tokens, n_heads, head_size)
        :return: (batch, n_heads * head_repeats, maxlen, head_size).
            Positions beyond a sequence's length are left uninitialized.
        """
        x_padded = x.new_empty(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            # The unsqueezed source broadcasts over the head_repeats axis.
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """
        :param x_padded: (batch, n_heads, length, head_size)
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        # Plain int so the size argument does not depend on tensor-to-size
        # coercion.
        total_n_tokens = int(cu_seqlens[-1])
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """For CPU, V100 or other older GPUs.
        NOTE: torch SPDA supports nested tensor,
        but seems extremely slow. Choose to pad instead.

        Only handles prompts: ``cu_seqlens_q`` must be None or equal to
        ``cu_seqlens_k``, and q and k must have the same number of tokens.
        The dense mask is regenerated if its dtype or device differs
        from q's.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."

        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        # k and v heads are repeated q_k_ratio times to match q's heads.
        k2, v2 = [
            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
            for x in [k, v]
        ]
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or
        `self.spda`(cpu, Volta, Turing or older)based on
        the type of device used and cuda compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
                Support grouped attention, with `q[:, i*r:(i*r + r)]`
                is correspondent to `k[:, i]`, where `r` is the q/k ratio.
        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
                    e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of
                    sample i
        cu_seqlens_q: shape=(batch_size + 1, ).
                    Default None: same as cu_seqlens_k for prefilling or
                    [0, 1, .., batch_size] for decoding.
                    The only case you need to specify
                    is when q is a mix of prefilling
                    and decoding.
        sm_scale: softmax scale, default to 1/sqrt(head_size).

        return: tensor of shape as q.
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
\ No newline at end of file
vllm/attention/ops/blocksparse_attention/utils.py
0 → 100644
View file @
b9e12416
# Helper functions for 3D sparse pattern
# These functions are not optimized and are very inefficient.
# Avoid calling them too frequently, or use a caching mechanism.
from
functools
import
lru_cache
import
torch
import
triton
from
scipy
import
sparse
def dense_to_crow_col(x: torch.Tensor):
    """Convert a 2D/3D tensor into CSR row-pointer / column-index tensors.

    Non-zero entries of `x` are treated as present. For 3D input each
    leading slice is converted separately and the per-slice column
    indices are right-padded with -1 to a common length.

    Returns `(crows, cols)` on the device of `x`; for 2D input the
    leading axis is dropped again.
    """
    n_dims = x.dim()
    target_device = x.device
    pad_value = -1
    assert n_dims in (2, 3)

    layers = x[None] if n_dims == 2 else x
    csr_layers = [
        sparse.csr_matrix(layer.bool().cpu().numpy()) for layer in layers
    ]

    row_ptrs = torch.vstack(
        [torch.from_numpy(m.indptr) for m in csr_layers])
    col_lists = [torch.from_numpy(m.indices) for m in csr_layers]

    # Pad every slice's column list to the longest one so they stack.
    widest = max(len(c) for c in col_lists)
    col_ids = torch.vstack([
        torch.cat([c, c.new_full((widest - c.shape[0], ), pad_value)])
        for c in col_lists
    ])

    if n_dims == 2:
        row_ptrs, col_ids = row_ptrs[0], col_ids[0]
    return row_ptrs.to(target_device), col_ids.to(target_device)
def crow_col_to_dense(crows: torch.Tensor,
                      cols: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    """Expand CSR row pointers / column indices into a dense 0/1 tensor.

    Accepts 1-D (single matrix) or 2-D (stacked matrices) `crows`/`cols`.
    The number of dense columns is `cols.max() + 1`. Only the index range
    selected by the row pointers is read, so trailing padding in `cols`
    is ignored. The result is returned on the device of `crows`.
    """
    single = crows.dim() == 1
    if single:
        crows, cols = crows[None], cols[None]
    target_device = crows.device
    crows, cols = crows.cpu(), cols.cpu()  # faster in cpu

    n_layers = crows.shape[0]
    n_rows = crows.shape[1] - 1
    dense = torch.zeros((n_layers, n_rows, cols.max() + 1), dtype=dtype)
    for layer in range(n_layers):
        for row in range(n_rows):
            lo, hi = crows[layer, row], crows[layer, row + 1]
            dense[layer, row, cols[layer, lo:hi]] = 1

    if single:
        dense = dense[0]
    return dense.to(target_device)
def dense_to_ccol_row(x: torch.Tensor):
    """CSC counterpart of `dense_to_crow_col`.

    Swaps the last two axes of `x` and delegates to the CSR conversion.
    """
    return dense_to_crow_col(x.transpose(-2, -1))
def ccol_row_to_dense(ccol: torch.Tensor,
                      rows: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    """Expand CSC column pointers / row indices into a dense tensor.

    The indices are expanded with `crow_col_to_dense` and the last two
    axes are then swapped. `transpose(-2, -1)` is used instead of
    `permute(0, 2, 1)` so that the 2D result produced for 1-D inputs is
    handled as well as the 3D result for stacked inputs.
    """
    return crow_col_to_dense(ccol, rows, dtype).transpose(-2,
                                                          -1).contiguous()
def _get_sparse_attn_mask_homo_head(
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 128,
    local_blocks: int = 4,
    vert_stride: int = 4,
    return_dense: bool = False,
):
    """Build the block-sparse pattern shared by all heads.

    A key block is kept for a query block when it is not after the query
    block and is either within `local_blocks` of it or its 1-based index
    is a multiple of `vert_stride`.

    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format (for the last `cdiv(q_len, block_size)`
            query-block rows).
        - block dense mask
        - all token dense mask (be aware that it can be
            OOM if it is too big) if `return_dense==True`,
            otherwise, None
    """
    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        block_ids = torch.arange(num_blocks)
        q_pos = block_ids[:, None]
        k_pos = block_ids[None]

        not_future = q_pos >= k_pos
        is_local = q_pos - k_pos < local_blocks
        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0

        block_mask_dense = (not_future &
                            (is_local | mask_vert_strided)).to(device).to(dtype)
        num_blocks_q = triton.cdiv(q_len, block_size)
        last_q_rows = block_mask_dense[-num_blocks_q:].contiguous()
        block_mask_dense_output = dense_to_crow_col(last_q_rows)

    mask_dense = None
    if return_dense:
        # Expand each block entry to block_size x block_size tokens, then
        # apply a token-level lower-triangular mask.
        token_block = block_mask_dense.new_ones((block_size, block_size))
        mask_dense = torch.kron(block_mask_dense, token_block)
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask

    return (
        block_mask_dense_output,
        block_mask_dense,
        mask_dense,
    )
def binary_mask_to_bias(mask_dense: torch.Tensor):
    """Turn a binary keep-mask into an additive bias.

    Computes `1 - mask_dense` into a new tensor and replaces every
    non-zero entry of it with -inf, so entries equal to 1 in the input
    become 0 and entries equal to 0 become -inf. The input tensor is not
    modified.
    """
    bias = 1 - mask_dense
    skipped = bias.bool()
    bias.masked_fill_(skipped, -torch.inf)
    return bias
def get_head_sliding_step(n_heads: int,
                          vert_stride: int,
                          homo_head: bool = False):
    """Return the per-head offset step for the vertical-stride pattern.

    0 when all heads share one pattern (`homo_head`); otherwise
    `int(vert_stride / n_heads)`, but never less than 1.
    """
    return 0 if homo_head else max(1, int(vert_stride / n_heads))
@lru_cache
def get_sparse_attn_mask(
    n_heads: int,
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 64,
    local_blocks: int = 4,
    vert_stride: int = 4,
    homo_head: bool = True,
    return_dense: bool = False,
    dense_mask_type: str = "binary",
):
    """Build the (optionally per-head) block-sparse attention pattern.

    Results are memoized with `lru_cache`, so repeated calls with the same
    arguments return the same objects.

    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
        or "bias" (-inf for skip token, 0 or others)
    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format.
        - block dense mask
        - all token dense mask (be aware that it can be OOM if it
            is too big) if `return_dense==True`, otherwise, None
    """
    assert dense_mask_type in ("binary", "bias")
    want_bias = dense_mask_type == "bias"

    if homo_head:
        # One pattern for all heads: compute once and expand per head.
        with torch.no_grad():
            (crow, col), block_mask_dense, mask_dense = (
                _get_sparse_attn_mask_homo_head(
                    q_len,
                    max_seqlen,
                    dtype,
                    device,
                    block_size,
                    local_blocks,
                    vert_stride,
                    return_dense,
                ))
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads,
                                                     *mask_dense.shape)
                if want_bias:
                    mask_dense = binary_mask_to_bias(mask_dense)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        num_blocks = triton.cdiv(max_seqlen, block_size)
        q_pos = torch.arange(num_blocks)[None, :, None]
        k_pos = torch.arange(num_blocks)[None, None]
        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
        # Each head's strided columns are shifted by h * head_sliding_step.
        per_head_strided = [
            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
            vert_stride == 0 for h in range(n_heads)
        ]
        mask_vert_strided = torch.vstack(per_head_strided)[:, None]

        not_future = q_pos >= k_pos
        is_local = q_pos - k_pos < local_blocks
        block_mask_dense = (not_future &
                            (is_local | mask_vert_strided)).to(device).to(dtype)
        num_blocks_q = triton.cdiv(q_len, block_size)
        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]

    mask_dense = None
    if return_dense:
        token_block = block_mask_dense.new_ones((block_size, block_size))
        mask_dense = torch.kron(block_mask_dense, token_block)
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
        if want_bias:
            mask_dense = binary_mask_to_bias(mask_dense)

    return (
        dense_to_crow_col(block_mask_dense_output),
        block_mask_dense,
        mask_dense,
    )
vllm/attention/ops/paged_attn.py
View file @
b9e12416
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -16,8 +16,8 @@ class PagedAttentionMetadata:
...
@@ -16,8 +16,8 @@ class PagedAttentionMetadata:
# (batch_size,). The length of sequences (entire tokens seen so far) per
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
# sequence.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch.
# Maximum sequence length in the batch.
0 if it is prefill-only batch.
max_seq_len
:
Optional
[
int
]
max_
decode_
seq_len
:
int
# (batch_size, max_blocks_per_seq).
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
...
@@ -31,7 +31,7 @@ class PagedAttention:
...
@@ -31,7 +31,7 @@ class PagedAttention:
@
staticmethod
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
256
]
return
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
@@ -91,9 +91,21 @@ class PagedAttention:
...
@@ -91,9 +91,21 @@ class PagedAttention:
scale
:
float
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
kv_scale
:
float
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
if
blocksparse_vert_stride
is
not
None
and
blocksparse_vert_stride
>
1
:
# use blocksparse paged attention
block_size
=
value_cache
.
size
(
-
1
)
assert
(
blocksparse_block_size
>
0
and
blocksparse_block_size
%
block_size
==
0
),
\
(
f
"
{
blocksparse_block_size
=
}
needs to be a multiple of"
f
"
{
block_size
=
}
used in block_tables."
)
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
...
@@ -107,6 +119,7 @@ class PagedAttention:
...
@@ -107,6 +119,7 @@ class PagedAttention:
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_seq_len
<=
8192
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
if
use_v1
:
if
use_v1
:
# Run PagedAttention V1.
# Run PagedAttention V1.
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
...
@@ -123,6 +136,11 @@ class PagedAttention:
...
@@ -123,6 +136,11 @@ class PagedAttention:
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
...
@@ -155,6 +173,11 @@ class PagedAttention:
...
@@ -155,6 +173,11 @@ class PagedAttention:
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
)
return
output
return
output
...
@@ -166,7 +189,7 @@ class PagedAttention:
...
@@ -166,7 +189,7 @@ class PagedAttention:
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
sub
query_start_loc
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
seq_lens_tensor
:
torch
.
Tensor
,
seq_lens_tensor
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_query_len
:
int
,
max_query_len
:
int
,
...
@@ -182,8 +205,8 @@ class PagedAttention:
...
@@ -182,8 +205,8 @@ class PagedAttention:
key_cache
,
key_cache
,
value_cache
,
value_cache
,
block_tables
,
block_tables
,
#
sub
query_start_loc is (batch_size + 1,)
# query_start_loc is (batch_size + 1,)
sub
query_start_loc
[:
-
1
],
query_start_loc
[:
-
1
],
seq_lens_tensor
,
seq_lens_tensor
,
context_lens
,
context_lens
,
max_query_len
,
max_query_len
,
...
@@ -196,7 +219,7 @@ class PagedAttention:
...
@@ -196,7 +219,7 @@ class PagedAttention:
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
...
@@ -209,7 +232,7 @@ class PagedAttention:
...
@@ -209,7 +232,7 @@ class PagedAttention:
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]]
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
...
...
vllm/attention/ops/prefix_prefill.py
View file @
b9e12416
...
@@ -472,7 +472,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -472,7 +472,8 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_bl
,
stride_v_cache_bl
,
num_queries_per_kv
:
int
,
num_queries_per_kv
:
int
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
# head size
BLOCK_DMODEL_PADDED
:
tl
.
constexpr
,
# head size padded to a power of 2
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
# attn_bias[]
# attn_bias[]
...
@@ -493,21 +494,24 @@ if triton.__version__ >= "2.1.0":
...
@@ -493,21 +494,24 @@ if triton.__version__ >= "2.1.0":
# initialize offsets
# initialize offsets
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
_PADDED
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_q
=
(
off_q
=
(
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
(
cur_batch_in_all_start_index
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
*
stride_qd
)
q
=
tl
.
load
(
dim_mask
=
tl
.
where
(
Q
+
off_q
,
tl
.
arange
(
0
,
BLOCK_DMODEL_PADDED
)
<
BLOCK_DMODEL
,
1
,
0
).
to
(
tl
.
int1
)
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
,
other
=
0.0
)
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
# # initialize pointer to m and l
# # initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
_PADDED
],
dtype
=
tl
.
float32
)
alibi_slope
=
tl
.
load
(
Alibi_slopes
+
cur_head
)
alibi_slope
=
tl
.
load
(
Alibi_slopes
+
cur_head
)
alibi_start_q
=
tl
.
arange
(
alibi_start_q
=
tl
.
arange
(
...
@@ -532,8 +536,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -532,8 +536,9 @@ if triton.__version__ >= "2.1.0":
offs_d
[
None
,
:]
*
stride_v_cache_d
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[:,
None
]
&
other
=
0.0
)
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
# [D,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
...
@@ -567,7 +572,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -567,7 +572,8 @@ if triton.__version__ >= "2.1.0":
acc
=
acc
*
acc_scale
[:,
None
]
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
v
=
tl
.
load
(
V_cache
+
off_v
,
v
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
,
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -600,8 +606,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -600,8 +606,9 @@ if triton.__version__ >= "2.1.0":
# -- compute qk ----
# -- compute qk ----
k
=
tl
.
load
(
k_ptrs
+
k
=
tl
.
load
(
k_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_kbs
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
mask
=
dim_mask
[:,
None
]
&
cur_batch_seq_len
-
cur_batch_ctx_len
,
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
...
@@ -637,8 +644,9 @@ if triton.__version__ >= "2.1.0":
...
@@ -637,8 +644,9 @@ if triton.__version__ >= "2.1.0":
# update acc
# update acc
v
=
tl
.
load
(
v_ptrs
+
v
=
tl
.
load
(
v_ptrs
+
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
(
cur_batch_in_all_start_index
+
start_n
)
*
stride_vbs
,
mask
=
(
start_n
+
offs_n
[:,
None
])
<
mask
=
dim_mask
[
None
,
:]
&
cur_batch_seq_len
-
cur_batch_ctx_len
,
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
...
@@ -656,7 +664,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -656,7 +664,8 @@ if triton.__version__ >= "2.1.0":
out_ptrs
=
Out
+
off_o
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
tl
.
store
(
out_ptrs
,
acc
,
acc
,
mask
=
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
)
mask
=
dim_mask
[
None
,
:]
&
(
offs_m
[:,
None
]
<
cur_batch_seq_len
-
cur_batch_ctx_len
))
return
return
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -688,9 +697,12 @@ if triton.__version__ >= "2.1.0":
...
@@ -688,9 +697,12 @@ if triton.__version__ >= "2.1.0":
grid
=
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
BLOCK
))
# batch, head,
grid
=
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
BLOCK
))
# batch, head,
# 0 means "disable"
if
sliding_window
is
None
or
sliding_window
<=
0
:
sliding_window
=
0
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
8
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
assert
Lk
==
Lk_padded
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
k
,
k
,
...
@@ -735,6 +747,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -735,6 +747,7 @@ if triton.__version__ >= "2.1.0":
num_queries_per_kv
=
num_queries_per_kv
,
num_queries_per_kv
=
num_queries_per_kv
,
BLOCK_M
=
BLOCK
,
BLOCK_M
=
BLOCK
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
...
@@ -785,7 +798,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -785,7 +798,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
if
sliding_window
is
not
None
else
0
,
SLIDING_WINDOW
=
sliding_window
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
)
)
...
...
vllm/attention/ops/triton_flash_attention.py
View file @
b9e12416
...
@@ -239,6 +239,16 @@ def _attn_fwd_inner(
...
@@ -239,6 +239,16 @@ def _attn_fwd_inner(
num_stages
=
1
,
num_stages
=
1
,
num_warps
=
8
,
num_warps
=
8
,
),
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
1
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
triton
.
Config
(
{
{
"BLOCK_M"
:
128
,
"BLOCK_M"
:
128
,
...
...
vllm/attention/selector.py
View file @
b9e12416
import
enum
import
enum
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Type
from
typing
import
Optional
,
Type
import
torch
import
torch
...
@@ -21,14 +21,33 @@ class _Backend(enum.Enum):
...
@@ -21,14 +21,33 @@ class _Backend(enum.Enum):
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
dtype
:
torch
.
dtype
)
->
Type
[
AttentionBackend
]:
def
get_attn_backend
(
backend
=
_which_attn_to_use
(
dtype
)
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
is_blocksparse
:
bool
=
False
,
)
->
Type
[
AttentionBackend
]:
if
is_blocksparse
:
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
from
vllm.attention.backends.blocksparse_attn
import
(
BlocksparseFlashAttentionBackend
)
return
BlocksparseFlashAttentionBackend
"""Determine which attention backend to use and only import
the selected backend module.
"""
backend
=
which_attn_to_use
(
num_heads
,
head_size
,
num_kv_heads
,
sliding_window
,
dtype
,
kv_cache_dtype
,
block_size
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using FlashAttention-2 backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
)
return
FlashAttentionBackend
return
FlashAttentionBackend
el
if
backend
==
_Backend
.
XFORMERS
:
if
backend
==
_Backend
.
XFORMERS
:
logger
.
info
(
"Using XFormers backend."
)
logger
.
info
(
"Using XFormers backend."
)
from
vllm.attention.backends.xformers
import
(
# noqa: F401
from
vllm.attention.backends.xformers
import
(
# noqa: F401
XFormersBackend
)
XFormersBackend
)
...
@@ -44,48 +63,102 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
...
@@ -44,48 +63,102 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
return
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend. "
)
logger
.
warning
(
"Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
return
FlashInferBackend
else
:
else
:
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
def
_which_attn_to_use
(
dtype
:
torch
.
dtype
)
->
_Backend
:
def
which_attn_to_use
(
num_heads
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
)
->
_Backend
:
"""Returns which flash attention backend to use."""
"""Returns which flash attention backend to use."""
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
# Check the environment variable and override if specified
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
backend_members
=
_Backend
.
__members__
if
backend_by_env_var
not
in
backend_members
:
raise
ValueError
(
f
"Invalid attention backend '
{
backend_by_env_var
}
'. "
f
"Available backends:
{
', '
.
join
(
backend_members
)
}
"
"(case-sensitive)."
)
selected_backend
=
_Backend
[
backend_by_env_var
]
if
is_cpu
():
if
is_cpu
():
if
selected_backend
!=
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
return
_Backend
.
TORCH_SDPA
if
is_hip
():
if
is_hip
():
# AMD GPUs.
# AMD GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
# not Instinct series GPUs.
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
logger
.
info
(
"flash_atten is not supported on NAVI GPUs."
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
return
_Backend
.
ROCM_FLASH
return
_Backend
.
ROCM_FLASH
# NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
# Volta and Turing NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
# Volta and Turing NVIDIA GPUs.
"GPUs."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs."
)
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
selected_backend
=
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for dtype other than "
elif
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
"torch.float16 or torch.bfloat16."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
try
:
selected_backend
=
_Backend
.
XFORMERS
import
flash_attn
# noqa: F401
elif
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
except
ImportError
:
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
"Cannot use FlashAttention-2 backend because the flash_attn "
selected_backend
=
_Backend
.
XFORMERS
"package is not found. Please install it for better performance."
)
elif
block_size
%
16
!=
0
:
return
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for block size not "
backend_by_env_var
=
envs
.
VLLM_ATTENTION_BACKEND
"divisible by 16."
)
if
backend_by_env_var
is
not
None
:
selected_backend
=
_Backend
.
XFORMERS
return
_Backend
[
backend_by_env_var
]
elif
sliding_window
is
not
None
:
logger
.
info
(
"Cannot use FlashAttention-2 backend due to sliding window."
)
selected_backend
=
_Backend
.
XFORMERS
# Default case.
# FlashAttn is valid for the model, checking if the package is installed.
return
_Backend
.
FLASH_ATTN
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
try
:
import
vllm_flash_attn
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
supported_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
supported_sizes
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for head size %d."
,
head_size
)
selected_backend
=
_Backend
.
XFORMERS
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance."
)
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
Prev
1
…
9
10
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment