Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
448
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
909 additions
and
169 deletions
+909
-169
vllm/assets/base.py
vllm/assets/base.py
+29
-1
vllm/assets/image.py
vllm/assets/image.py
+16
-25
vllm/attention/__init__.py
vllm/attention/__init__.py
+4
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+53
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+197
-42
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+200
-31
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+5
-0
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+6
-1
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+5
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+9
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+5
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+128
-23
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+9
-1
vllm/attention/layer.py
vllm/attention/layer.py
+1
-1
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+1
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+6
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+73
-24
vllm/attention/selector.py
vllm/attention/selector.py
+111
-13
vllm/block.py
vllm/block.py
+45
-3
No files found.
Too many changes to show.
To preserve performance only
448 of 448+
files are displayed.
Plain diff
Email patch
vllm/assets/base.py
View file @
af7f4372
from
functools
import
lru_cache
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.connections
import
global_http_connection
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
vLLM_S3_BUCKET_URL
=
"https://vllm-public-assets.s3.us-west-2.amazonaws.com"
def
get_cache_dir
():
def
get_cache_dir
()
->
Path
:
"""Get the path to the cache for storing downloaded assets."""
"""Get the path to the cache for storing downloaded assets."""
path
=
Path
(
envs
.
VLLM_ASSETS_CACHE
)
path
=
Path
(
envs
.
VLLM_ASSETS_CACHE
)
path
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
path
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
return
path
return
path
@
lru_cache
def
get_vllm_public_assets
(
filename
:
str
,
s3_prefix
:
Optional
[
str
]
=
None
)
->
Path
:
"""
Download an asset file from ``s3://vllm-public-assets``
and return the path to the downloaded file.
"""
asset_directory
=
get_cache_dir
()
/
"vllm_public_assets"
asset_directory
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
asset_path
=
asset_directory
/
filename
if
not
asset_path
.
exists
():
if
s3_prefix
is
not
None
:
filename
=
s3_prefix
+
"/"
+
filename
global_http_connection
.
download_file
(
f
"
{
vLLM_S3_BUCKET_URL
}
/
{
filename
}
"
,
asset_path
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
return
asset_path
vllm/assets/image.py
View file @
af7f4372
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
Literal
from
typing
import
Literal
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.assets.base
import
get_vllm_public_assets
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
.base
import
get_cache_dir
VLM_IMAGES_DIR
=
"vision_model_images"
@
lru_cache
def
get_air_example_data_2_asset
(
filename
:
str
)
->
Image
.
Image
:
"""
Download and open an image from
``s3://air-example-data-2/vllm_opensource_llava/``.
"""
image_directory
=
get_cache_dir
()
/
"air-example-data-2"
image_directory
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
image_path
=
image_directory
/
filename
if
not
image_path
.
exists
():
base_url
=
"https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
global_http_connection
.
download_file
(
f
"
{
base_url
}
/
{
filename
}
"
,
image_path
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
return
Image
.
open
(
image_path
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -36,4 +15,16 @@ class ImageAsset:
...
@@ -36,4 +15,16 @@ class ImageAsset:
@
property
@
property
def
pil_image
(
self
)
->
Image
.
Image
:
def
pil_image
(
self
)
->
Image
.
Image
:
return
get_air_example_data_2_asset
(
f
"
{
self
.
name
}
.jpg"
)
image_path
=
get_vllm_public_assets
(
filename
=
f
"
{
self
.
name
}
.jpg"
,
s3_prefix
=
VLM_IMAGES_DIR
)
return
Image
.
open
(
image_path
)
@
property
def
image_embeds
(
self
)
->
torch
.
Tensor
:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path
=
get_vllm_public_assets
(
filename
=
f
"
{
self
.
name
}
.pt"
,
s3_prefix
=
VLM_IMAGES_DIR
)
return
torch
.
load
(
image_path
)
vllm/attention/__init__.py
View file @
af7f4372
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
)
AttentionMetadataBuilder
,
AttentionState
,
AttentionType
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
...
@@ -8,7 +9,9 @@ __all__ = [
...
@@ -8,7 +9,9 @@ __all__ = [
"Attention"
,
"Attention"
,
"AttentionBackend"
,
"AttentionBackend"
,
"AttentionMetadata"
,
"AttentionMetadata"
,
"AttentionType"
,
"AttentionMetadataBuilder"
,
"AttentionMetadataBuilder"
,
"Attention"
,
"Attention"
,
"AttentionState"
,
"get_attn_backend"
,
"get_attn_backend"
,
]
]
vllm/attention/backends/abstract.py
View file @
af7f4372
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
...
@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
...
@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
import
torch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
)
class
AttentionType
(
Enum
):
class
AttentionType
(
Enum
):
...
@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
...
@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
get_state_cls
()
->
Type
[
"AttentionState"
]:
raise
NotImplementedError
@
classmethod
@
classmethod
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
cls
,
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
return
cls
.
get_metadata_cls
()(
*
args
,
**
kwargs
)
...
@@ -75,6 +83,9 @@ class AttentionBackend(ABC):
...
@@ -75,6 +83,9 @@ class AttentionBackend(ABC):
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
raise
NotImplementedError
@
dataclass
@
dataclass
class
AttentionMetadata
:
class
AttentionMetadata
:
...
@@ -123,6 +134,47 @@ class AttentionMetadata:
...
@@ -123,6 +134,47 @@ class AttentionMetadata:
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
T
=
TypeVar
(
"T"
,
bound
=
AttentionMetadata
)
class
AttentionState
(
ABC
,
Generic
[
T
]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@
abstractmethod
def
__init__
(
self
,
runner
:
"ModelRunnerBase"
):
...
@
abstractmethod
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
"""Context manager used when capturing CUDA graphs."""
yield
@
abstractmethod
def
graph_clone
(
self
,
batch_size
:
int
)
->
"AttentionState[T]"
:
"""Clone attention state to save in CUDA graph metadata."""
...
@
abstractmethod
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
)
->
T
:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@
abstractmethod
def
get_graph_input_buffers
(
self
,
attn_metadata
:
T
)
->
Dict
[
str
,
Any
]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@
abstractmethod
def
prepare_graph_input_buffers
(
self
,
input_buffers
:
Dict
[
str
,
Any
],
attn_metadata
:
T
)
->
None
:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@
abstractmethod
def
begin_forward
(
self
,
model_input
:
"ModelRunnerInputBase"
)
->
None
:
"""Prepare state for forward pass."""
...
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
class
AttentionMetadataBuilder
(
ABC
,
Generic
[
T
]):
"""Abstract class for attention metadata builders."""
"""Abstract class for attention metadata builders."""
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
af7f4372
...
@@ -5,7 +5,8 @@ import torch
...
@@ -5,7 +5,8 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.blocksparse_attention.interface
import
(
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
LocalStridedBlockSparseAttn
,
get_head_sliding_step
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
...
@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
...
@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"BlocksparseFlashAttentionMetadataBuilder"
]:
return
BlocksparseFlashAttentionMetadataBuilder
return
BlocksparseFlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/flash_attn.py
View file @
af7f4372
...
@@ -3,21 +3,123 @@ from dataclasses import dataclass
...
@@ -3,21 +3,123 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_varlen_func"
,
mutates_args
=
[])
def
flash_attn_varlen_func
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# custom op does not support tuple input
real_window_size
:
Tuple
[
int
,
int
]
if
window_size
is
None
:
real_window_size
=
(
-
1
,
-
1
)
else
:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
return
_flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
real_window_size
,
softcap
=
softcap
,
alibi_slopes
=
alibi_slopes
,
block_table
=
block_table
,
)
@
flash_attn_varlen_func
.
register_fake
# type: ignore
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_with_kvcache"
,
mutates_args
=
[])
def
flash_attn_with_kvcache
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
_flash_attn_with_kvcache
(
decode_query
,
key_cache
,
value_cache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
alibi_slopes
=
alibi_slopes
,
softcap
=
softcap
,
)
@
flash_attn_with_kvcache
.
register_fake
# type: ignore
def
_
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
decode_query
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -41,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -41,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
return
FlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -196,6 +302,51 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -196,6 +302,51 @@ class FlashAttentionMetadata(AttentionMetadata):
)
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
num_decode_tokens
==
num_seqs
assert
self
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
self
.
seq_lens
is
not
None
assert
len
(
self
.
seq_lens
)
==
num_seqs
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
self
.
max_query_len
==
1
assert
self
.
max_prefill_seq_len
==
0
assert
self
.
max_decode_seq_len
==
max
(
self
.
seq_lens
)
assert
self
.
query_start_loc
is
not
None
assert
self
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
self
.
seq_start_loc
is
not
None
assert
self
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
assert
self
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
class
FlashAttentionMetadataBuilder
(
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
...
@@ -259,7 +410,11 @@ class FlashAttentionMetadataBuilder(
...
@@ -259,7 +410,11 @@ class FlashAttentionMetadataBuilder(
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
and
block_tables
is
not
None
):
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
# Compute slot mapping.
...
@@ -310,7 +465,8 @@ class FlashAttentionMetadataBuilder(
...
@@ -310,7 +465,8 @@ class FlashAttentionMetadataBuilder(
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
=
device
,
non_blocking
=
True
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
@@ -320,15 +476,15 @@ class FlashAttentionMetadataBuilder(
...
@@ -320,15 +476,15 @@ class FlashAttentionMetadataBuilder(
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
seq_lens_tensor
=
async_
tensor
_h2d
(
seq_lens
,
torch
.
int
,
device
,
dtype
=
torch
.
int
,
self
.
runner
.
pin_memory
)
device
=
device
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
self
.
runner
.
pin_memory
)
dtype
=
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -344,10 +500,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -344,10 +500,6 @@ class FlashAttentionMetadataBuilder(
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
FlashAttentionMetadata
(
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
@@ -516,7 +668,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -516,7 +668,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
out
=
flash_attn_varlen_func
(
out
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -536,33 +688,36 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -536,33 +688,36 @@ class FlashAttentionImpl(AttentionImpl):
# prefix-enabled attention
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
output
[:
q
=
query
,
num_prefill_tokens
]
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
k
=
key_cache
,
q
=
query
,
v
=
value_cache
,
k
=
key_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
v
=
value_cache
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
max_seqlen_k
=
max_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
).
squeeze
(
1
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
).
squeeze
(
1
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flashinfer.py
View file @
af7f4372
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
import
vllm.attention.backends.flash_attn
# noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
except
ImportError
:
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
import
torch
import
torch
...
@@ -16,12 +21,13 @@ from vllm import _custom_ops as ops
...
@@ -16,12 +21,13 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionState
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
@@ -45,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
...
@@ -45,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"FlashInferMetadataBuilder"
]:
return
FlashInferMetadataBuilder
return
FlashInferMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"FlashInferState"
]:
return
FlashInferState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -74,6 +84,162 @@ class FlashInferBackend(AttentionBackend):
...
@@ -74,6 +84,162 @@ class FlashInferBackend(AttentionBackend):
return
[
64
,
128
,
256
]
return
[
64
,
128
,
256
]
class
FlashInferState
(
AttentionState
):
def
__init__
(
self
,
runner
):
self
.
runner
=
runner
self
.
_is_graph_capturing
=
False
self
.
_workspace_buffer
=
None
self
.
_decode_wrapper
=
None
self
.
_prefill_wrapper
=
None
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
self
.
_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
if
self
.
_decode_wrapper
is
None
:
num_qo_heads
=
(
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
))
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
self
.
runner
.
parallel_config
)
use_tensor_cores
=
(
num_qo_heads
//
num_kv_heads
)
not
in
\
(
1
,
2
,
4
,
8
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
,
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
self
.
_is_graph_capturing
=
True
self
.
_graph_decode_wrapper
=
None
self
.
_graph_slot_mapping
=
torch
.
full
((
max_batch_size
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
self
.
_graph_seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
self
.
_graph_block_tables
=
torch
.
from_numpy
(
self
.
runner
.
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
)
self
.
_graph_decode_workspace_buffer
=
self
.
_get_workspace_buffer
()
self
.
_graph_indices_buffer
=
torch
.
empty
(
max_batch_size
*
self
.
runner
.
cache_config
.
num_gpu_blocks
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
self
.
_graph_indptr_buffer
=
torch
.
empty
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
self
.
_graph_last_page_len_buffer
=
torch
.
empty
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
yield
self
.
_is_graph_capturing
=
False
del
self
.
_graph_slot_mapping
del
self
.
_graph_seq_lens
del
self
.
_graph_block_tables
del
self
.
_graph_decode_workspace_buffer
del
self
.
_graph_indices_buffer
del
self
.
_graph_indptr_buffer
del
self
.
_graph_last_page_len_buffer
del
self
.
_graph_decode_wrapper
def
graph_clone
(
self
,
batch_size
:
int
):
assert
self
.
_is_graph_capturing
state
=
self
.
__class__
(
self
.
runner
)
state
.
_workspace_buffer
=
self
.
_graph_decode_workspace_buffer
state
.
_decode_wrapper
=
self
.
_graph_decode_wrapper
state
.
_prefill_wrapper
=
self
.
_get_prefill_wrapper
()
return
state
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
):
assert
self
.
_is_graph_capturing
_indptr_buffer
=
self
.
_graph_indptr_buffer
[:
batch_size
+
1
]
_last_page_len_buffer
=
self
.
_graph_last_page_len_buffer
[:
batch_size
]
num_qo_heads
=
(
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
))
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
self
.
runner
.
parallel_config
)
use_tensor_cores
=
(
num_qo_heads
//
num_kv_heads
)
not
in
\
(
1
,
2
,
4
,
8
)
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
use_tensor_cores
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
paged_kv_indptr_tensor_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
paged_kv_indices_tensor_host
=
torch
.
arange
(
0
,
batch_size
,
dtype
=
torch
.
int32
)
paged_kv_last_page_len_tensor_host
=
torch
.
full
((
batch_size
,
),
self
.
runner
.
block_size
,
dtype
=
torch
.
int32
)
query_start_loc_host
=
torch
.
arange
(
0
,
batch_size
+
1
,
dtype
=
torch
.
int32
)
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
max_prefill_seq_len
=
0
,
block_tables
=
self
.
_graph_block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor_host
,
paged_kv_indices
=
paged_kv_indices_tensor_host
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor_host
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
page_size
=
self
.
runner
.
block_size
,
seq_start_loc
=
None
,
query_start_loc
=
query_start_loc_host
,
device
=
self
.
runner
.
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
True
,
decode_wrapper
=
self
.
_graph_decode_wrapper
,
prefill_wrapper
=
None
)
attn_metadata
.
begin_forward
()
return
attn_metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
):
return
{
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
}
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
):
return
def
begin_forward
(
self
,
model_input
):
assert
not
self
.
_is_graph_capturing
state
=
self
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
state
=
(
self
.
runner
.
graph_runners
[
model_input
.
virtual_engine
]
[
batch_size
].
attn_state
)
model_input
.
attn_metadata
.
prefill_wrapper
=
state
.
_get_prefill_wrapper
(
)
model_input
.
attn_metadata
.
decode_wrapper
=
state
.
_get_decode_wrapper
()
model_input
.
attn_metadata
.
begin_forward
()
@
dataclass
@
dataclass
class
FlashInferMetadata
(
AttentionMetadata
):
class
FlashInferMetadata
(
AttentionMetadata
):
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
...
@@ -116,6 +282,7 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -116,6 +282,7 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
is_profile_run
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Refer to
# Refer to
...
@@ -139,20 +306,20 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -139,20 +306,20 @@ class FlashInferMetadata(AttentionMetadata):
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
assert
batch_size
>=
0
#
The prefill stage does not read kv cache.
#
We will use flash attention for profiling to
#
Both paged_kv_indices and paged_kv_last_page_len are empty.
#
determine the number of blocks. Therefore,
#
paged_kv_indptr is a zero tensor with size batch_size + 1
.
#
we don't need to prepare the input for flashinfer for profile run
.
self
.
paged_kv_indptr
=
torch
.
zeros
(
batch_size
+
1
,
if
not
self
.
is_profile_run
:
device
=
self
.
device
)
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
self
.
prefill_wrapper
.
begin_forward
(
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
query_start_loc
,
self
.
paged_kv_indptr
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
paged_kv_indices
,
self
.
paged_kv_last_page_len
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
)
self
.
page_size
)
else
:
else
:
if
not
self
.
use_cuda_graph
:
if
not
self
.
use_cuda_graph
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indices
is
not
None
...
@@ -244,6 +411,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -244,6 +411,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
self
.
is_profile_run
:
bool
=
False
def
_add_seq_group
(
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
chunked_prefill_enabled
:
bool
):
...
@@ -300,6 +469,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -300,6 +469,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# and paged_kv_last_page_len for profile run because we will
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
# create dummy inputs.
if
is_profile_run
:
if
is_profile_run
:
self
.
is_profile_run
=
is_profile_run
return
return
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
...
@@ -356,7 +526,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -356,7 +526,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
,
non_blocking
=
True
)
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
...
@@ -371,12 +542,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -371,12 +542,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
device
=
device
)
self
.
runner
.
pin_memory
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
dtype
=
torch
.
long
,
self
.
runner
.
pin_memory
)
device
=
device
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -392,10 +564,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -392,10 +564,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -432,7 +600,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -432,7 +600,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
device
,
device
=
device
,
data_type
=
kv_cache_dtype
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
)
use_cuda_graph
=
use_captured_graph
,
is_profile_run
=
self
.
is_profile_run
)
class
FlashInferImpl
(
AttentionImpl
):
class
FlashInferImpl
(
AttentionImpl
):
...
@@ -516,7 +685,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -516,7 +685,7 @@ class FlashInferImpl(AttentionImpl):
# This happens when vllm runs the profiling to
# This happens when vllm runs the profiling to
# determine the number of blocks.
# determine the number of blocks.
if
kv_cache
is
None
:
if
kv_cache
is
None
:
output
=
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
af7f4372
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
...
@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
...
@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
return
IpexAttnMetadata
return
IpexAttnMetadata
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/openvino.py
View file @
af7f4372
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
,
Type
import
openvino
as
ov
import
openvino
as
ov
import
torch
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
AttentionMetadata
)
from
vllm.attention.backends.utils
import
CommonAttentionState
class
OpenVINOAttentionBackend
(
AttentionBackend
):
class
OpenVINOAttentionBackend
(
AttentionBackend
):
...
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
...
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
def
make_openvino_metadata
(
*
args
,
**
kwargs
)
->
"OpenVINOAttentionMetadata"
:
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
return
OpenVINOAttentionMetadata
(
*
args
,
**
kwargs
)
...
...
vllm/attention/backends/pallas.py
View file @
af7f4372
...
@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
...
@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
class
PallasAttentionBackend
(
AttentionBackend
):
class
PallasAttentionBackend
(
AttentionBackend
):
...
@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
...
@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"PallasMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"PallasMetadata"
]:
return
PallasMetadata
return
PallasMetadata
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
af7f4372
...
@@ -7,7 +7,8 @@ import torch
...
@@ -7,7 +7,8 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
return
ROCmFlashAttentionMetadataBuilder
return
ROCmFlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -508,6 +513,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -508,6 +513,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
,
query
,
key
,
key
,
value
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
...
@@ -517,6 +523,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -517,6 +523,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
self
.
sliding_window
[
0
],
k_scale
,
v_scale
,
)
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
vllm/attention/backends/torch_sdpa.py
View file @
af7f4372
...
@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
...
@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.utils
import
is_cpu
from
vllm.utils
import
is_cpu
...
@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
TorchSDPAMetadata
return
TorchSDPAMetadata
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
...
vllm/attention/backends/utils.py
View file @
af7f4372
"""Attention backend utils"""
"""Attention backend utils"""
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
import
numpy
as
np
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
from
vllm.utils
import
make_tensor_with_pad
AttentionState
)
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
# Error string(s) for encoder/decoder
# Error string(s) for encoder/decoder
# unsupported attention scenarios
# unsupported attention scenarios
...
@@ -13,6 +19,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
...
@@ -13,6 +19,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID
=
-
1
PAD_SLOT_ID
=
-
1
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL
=
256
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
@@ -46,6 +56,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
...
@@ -46,6 +56,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
return
start_idx
return
start_idx
def
_compute_slot_mapping_python
(
slot_mapping
:
List
[
int
],
block_table
:
List
[
int
],
range_start
:
int
,
range_end
:
int
,
block_size
:
int
):
for
i
in
range
(
range_start
,
range_end
):
block_number
=
block_table
[
i
//
block_size
]
block_offset
=
i
%
block_size
slot
=
block_number
*
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
def
_compute_slot_mapping_numpy
(
slot_mapping
:
List
[
int
],
block_table
:
List
[
int
],
range_start
:
int
,
range_end
:
int
,
block_size
:
int
):
block_table_array
=
np
.
array
(
block_table
)
idx
=
np
.
arange
(
range_start
,
range_end
)
block_offset
=
idx
%
block_size
idx
//=
block_size
seq_slot_mapping_array
=
block_table_array
[
idx
]
seq_slot_mapping_array
*=
block_size
seq_slot_mapping_array
+=
block_offset
slot_mapping
.
extend
(
seq_slot_mapping_array
)
def
compute_slot_mapping
(
is_profile_run
:
bool
,
slot_mapping
:
List
[
int
],
def
compute_slot_mapping
(
is_profile_run
:
bool
,
slot_mapping
:
List
[
int
],
seq_id
:
int
,
seq_len
:
int
,
context_len
:
int
,
seq_id
:
int
,
seq_len
:
int
,
context_len
:
int
,
start_idx
:
int
,
block_size
:
int
,
start_idx
:
int
,
block_size
:
int
,
...
@@ -67,13 +100,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
...
@@ -67,13 +100,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# sliding window is 8, and block size is 4, the first two
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
padding_mask_len
=
max
(
0
,
start_idx
-
context_len
)
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
padding_mask_len
)
range_start
=
max
(
start_idx
,
context_len
)
range_end
=
seq_len
numel
=
range_end
-
range_start
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
max
(
0
,
start_idx
-
context_len
))
for
i
in
range
(
max
(
start_idx
,
context_len
),
seq_len
):
# numpy implementation will be faster than python if we have
block_number
=
block_table
[
i
//
block_size
]
# many elements, otherwise it will be slower.
block_offset
=
i
%
block_size
if
numel
<
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL
:
slot
=
block_number
*
block_size
+
block_offset
_compute_slot_mapping_python
(
slot_mapping
,
block_table
,
range_start
,
slot_mapping
.
append
(
slot
)
range_end
,
block_size
)
else
:
_compute_slot_mapping_numpy
(
slot_mapping
,
block_table
,
range_start
,
range_end
,
block_size
)
TAttentionMetadata
=
TypeVar
(
"TAttentionMetadata"
,
bound
=
'AttentionMetadata'
)
TAttentionMetadata
=
TypeVar
(
"TAttentionMetadata"
,
bound
=
'AttentionMetadata'
)
...
@@ -181,7 +223,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -181,7 +223,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
device
)
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
,
non_blocking
=
True
)
else
:
else
:
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
self
.
block_tables
,
...
@@ -191,15 +234,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -191,15 +234,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
)
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
assert
device
is
not
None
dtype
=
torch
.
int
,
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
seq_lens_tensor
=
async_
tensor
_h2d
(
seq_lens
,
torch
.
int
,
device
,
dtype
=
torch
.
int
,
self
.
runner
.
pin_memory
)
device
=
device
)
query_lens_tensor
=
async_tensor_h2d
(
query_lens
,
torch
.
long
,
device
,
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
self
.
runner
.
pin_memory
)
dtype
=
torch
.
long
,
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
=
device
)
device
,
self
.
runner
.
pin_memory
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
...
@@ -215,10 +258,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -215,10 +258,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
dtype
=
query_start_loc
.
dtype
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
out
=
query_start_loc
[
1
:])
slot_mapping_tensor
=
torch
.
tensor
(
self
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
return
self
.
_metadata_cls
(
# type: ignore
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
@@ -235,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -235,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
)
)
class
CommonAttentionState
(
AttentionState
):
def
__init__
(
self
,
runner
:
"ModelRunnerBase"
):
self
.
runner
=
runner
self
.
_is_graph_capturing
=
False
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
self
.
_is_graph_capturing
=
True
self
.
_graph_slot_mapping
=
torch
.
full
((
max_batch_size
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
self
.
_graph_seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
self
.
_graph_block_tables
=
torch
.
from_numpy
(
self
.
runner
.
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
)
yield
self
.
_is_graph_capturing
=
False
del
self
.
_graph_slot_mapping
del
self
.
_graph_seq_lens
del
self
.
_graph_block_tables
def
graph_clone
(
self
,
batch_size
:
int
)
->
"CommonAttentionState"
:
assert
self
.
_is_graph_capturing
return
self
.
__class__
(
self
.
runner
)
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
):
assert
self
.
_is_graph_capturing
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
_graph_block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
)
return
attn_metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
)
->
Dict
[
str
,
Any
]:
return
{
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"seq_lens_tensor"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
)
->
None
:
input_buffers
[
"seq_lens_tensor"
].
copy_
(
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
def
begin_forward
(
self
,
model_input
)
->
None
:
return
vllm/attention/backends/xformers.py
View file @
af7f4372
...
@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonMetadataBuilder
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
...
@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
return
XFormersMetadataBuilder
return
XFormersMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -604,6 +609,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -604,6 +609,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query
,
query
,
key
,
key
,
value
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
...
@@ -613,6 +619,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -613,6 +619,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
,
self
.
sliding_window
,
k_scale
,
v_scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
...
...
vllm/attention/layer.py
View file @
af7f4372
...
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
...
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
.backends.abstract
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/attention/ops/ipex_attn.py
View file @
af7f4372
...
@@ -90,6 +90,7 @@ class PagedAttention:
...
@@ -90,6 +90,7 @@ class PagedAttention:
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
...
...
vllm/attention/ops/paged_attn.py
View file @
af7f4372
...
@@ -256,6 +256,7 @@ class PagedAttention:
...
@@ -256,6 +256,7 @@ class PagedAttention:
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
...
@@ -265,6 +266,8 @@ class PagedAttention:
...
@@ -265,6 +266,8 @@ class PagedAttention:
max_query_len
:
int
,
max_query_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
k_scale
:
float
,
v_scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
context_attention_fwd
(
context_attention_fwd
(
...
@@ -272,6 +275,7 @@ class PagedAttention:
...
@@ -272,6 +275,7 @@ class PagedAttention:
key
,
key
,
value
,
value
,
output
,
output
,
kv_cache_dtype
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
block_tables
,
block_tables
,
...
@@ -280,6 +284,8 @@ class PagedAttention:
...
@@ -280,6 +284,8 @@ class PagedAttention:
seq_lens_tensor
,
seq_lens_tensor
,
context_lens
,
context_lens
,
max_query_len
,
max_query_len
,
k_scale
,
v_scale
,
alibi_slopes
,
alibi_slopes
,
sliding_window
,
sliding_window
,
)
)
...
...
vllm/attention/ops/prefix_prefill.py
View file @
af7f4372
...
@@ -18,6 +18,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -18,6 +18,8 @@ if triton.__version__ >= "2.1.0":
V_cache
,
V_cache
,
B_Loc
,
B_Loc
,
sm_scale
,
sm_scale
,
k_scale
,
v_scale
,
B_Start_Loc
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
B_Ctxlen
,
B_Ctxlen
,
...
@@ -117,10 +119,15 @@ if triton.__version__ >= "2.1.0":
...
@@ -117,10 +119,15 @@ if triton.__version__ >= "2.1.0":
cur_kv_head
*
stride_v_cache_h
+
cur_kv_head
*
stride_v_cache_h
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
k_load
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
dim_mask
[:,
None
]
&
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
# [D,N]
other
=
0.0
)
# [D,N]
if
k_load
.
dtype
.
is_fp8
():
k
=
(
k_load
.
to
(
tl
.
float32
)
*
k_scale
).
to
(
q
.
dtype
)
else
:
k
=
k_load
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# [M,N]
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# [M,N]
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
...
@@ -161,12 +168,16 @@ if triton.__version__ >= "2.1.0":
...
@@ -161,12 +168,16 @@ if triton.__version__ >= "2.1.0":
acc_scale
=
l_i
/
l_i_new
*
alpha
acc_scale
=
l_i
/
l_i_new
*
alpha
acc
=
acc
*
acc_scale
[:,
None
]
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
v
=
tl
.
load
(
V_cache
+
off_v
,
v_load
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
dim_mask
[
None
,
:]
&
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
# [N,D]
other
=
0.0
)
# [N,D]
if
v_load
.
dtype
.
is_fp8
():
v
=
(
v_load
.
to
(
tl
.
float32
)
*
v_scale
).
to
(
q
.
dtype
)
else
:
v
=
v_load
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
)
acc
+=
tl
.
dot
(
p
,
v
)
# # update m_i and l_i
# # update m_i and l_i
l_i
=
l_i_new
l_i
=
l_i_new
...
@@ -225,8 +236,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -225,8 +236,8 @@ if triton.__version__ >= "2.1.0":
mask
=
dim_mask
[
None
,
:]
&
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_query_len
),
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_query_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
)
acc
+=
tl
.
dot
(
p
,
v
)
# update m_i and l_i
# update m_i and l_i
l_i
=
l_i_new
l_i
=
l_i_new
...
@@ -336,7 +347,6 @@ if triton.__version__ >= "2.1.0":
...
@@ -336,7 +347,6 @@ if triton.__version__ >= "2.1.0":
k
=
tl
.
load
(
K_cache
+
off_k
,
k
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
mask
=
(
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
other
=
0.0
)
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
qk
=
tl
.
where
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
,
qk
,
...
@@ -442,6 +452,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -442,6 +452,8 @@ if triton.__version__ >= "2.1.0":
V_cache
,
V_cache
,
B_Loc
,
B_Loc
,
sm_scale
,
sm_scale
,
k_scale
,
v_scale
,
B_Start_Loc
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
B_Ctxlen
,
B_Ctxlen
,
...
@@ -537,10 +549,15 @@ if triton.__version__ >= "2.1.0":
...
@@ -537,10 +549,15 @@ if triton.__version__ >= "2.1.0":
cur_kv_head
*
stride_v_cache_h
+
cur_kv_head
*
stride_v_cache_h
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
offs_d
[
None
,
:]
*
stride_v_cache_d
+
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
(
start_n
+
offs_n
[:,
None
])
%
block_size
*
stride_v_cache_bl
)
k
=
tl
.
load
(
K_cache
+
off_k
,
k_load
=
tl
.
load
(
K_cache
+
off_k
,
mask
=
dim_mask
[:,
None
]
&
mask
=
dim_mask
[:,
None
]
&
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[
None
,
:])
<
cur_batch_ctx_len
),
other
=
0.0
)
# [D,N]
other
=
0.0
)
# [D,N]
if
k_load
.
dtype
.
is_fp8
():
k
=
(
k_load
.
to
(
tl
.
float32
)
*
k_scale
).
to
(
q
.
dtype
)
else
:
k
=
k_load
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
qk
+=
tl
.
dot
(
q
,
k
)
...
@@ -573,12 +590,16 @@ if triton.__version__ >= "2.1.0":
...
@@ -573,12 +590,16 @@ if triton.__version__ >= "2.1.0":
# acc_scale = l_i / l_i_new * alpha
# acc_scale = l_i / l_i_new * alpha
acc
=
acc
*
acc_scale
[:,
None
]
acc
=
acc
*
acc_scale
[:,
None
]
# update acc
# update acc
v
=
tl
.
load
(
V_cache
+
off_v
,
v_load
=
tl
.
load
(
V_cache
+
off_v
,
mask
=
dim_mask
[
None
,
:]
&
mask
=
dim_mask
[
None
,
:]
&
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
if
v_load
.
dtype
.
is_fp8
():
v
=
(
v_load
.
to
(
tl
.
float32
)
*
v_scale
).
to
(
q
.
dtype
)
else
:
v
=
v_load
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
,
allow_tf32
=
False
)
acc
+=
tl
.
dot
(
p
,
v
,
allow_tf32
=
False
)
# update m_i and l_i
# update m_i and l_i
l_i
=
l_i_new
l_i
=
l_i_new
...
@@ -650,8 +671,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -650,8 +671,8 @@ if triton.__version__ >= "2.1.0":
((
start_n
+
offs_n
[:,
None
])
<
((
start_n
+
offs_n
[:,
None
])
<
cur_batch_seq_len
-
cur_batch_ctx_len
),
cur_batch_seq_len
-
cur_batch_ctx_len
),
other
=
0.0
)
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
+=
tl
.
dot
(
p
,
v
,
allow_tf32
=
False
)
acc
+=
tl
.
dot
(
p
,
v
,
allow_tf32
=
False
)
# update m_i and l_i
# update m_i and l_i
l_i
=
l_i_new
l_i
=
l_i_new
...
@@ -675,6 +696,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -675,6 +696,7 @@ if triton.__version__ >= "2.1.0":
k
,
k
,
v
,
v
,
o
,
o
,
kv_cache_dtype
:
str
,
k_cache
,
k_cache
,
v_cache
,
v_cache
,
b_loc
,
b_loc
,
...
@@ -682,17 +704,41 @@ if triton.__version__ >= "2.1.0":
...
@@ -682,17 +704,41 @@ if triton.__version__ >= "2.1.0":
b_seq_len
,
b_seq_len
,
b_ctx_len
,
b_ctx_len
,
max_input_len
,
max_input_len
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
None
):
sliding_window
=
None
):
cap
=
current_platform
.
get_device_capability
()
cap
=
current_platform
.
get_device_capability
()
BLOCK
=
32
if
cap
[
0
]
>=
8
else
32
BLOCK
=
32
if
cap
[
0
]
>=
8
else
32
NUM_WARPS
=
8
# need to reduce num. blocks when using fp32
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# due to increased use of GPU shared memory
if
q
.
dtype
is
torch
.
float32
:
if
q
.
dtype
is
torch
.
float32
:
BLOCK
=
BLOCK
//
2
BLOCK
=
BLOCK
//
2
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if
"fp8"
in
kv_cache_dtype
:
assert
(
k_cache
.
dtype
==
torch
.
uint8
)
assert
(
v_cache
.
dtype
==
torch
.
uint8
)
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
target_dtype
=
torch
.
float8_e4m3fn
elif
kv_cache_dtype
==
"fp8_e5m2"
:
target_dtype
=
torch
.
float8_e5m2
else
:
raise
ValueError
(
"Unsupported FP8 dtype:"
,
kv_cache_dtype
)
k_cache
=
k_cache
.
view
(
target_dtype
)
v_cache
=
v_cache
.
view
(
target_dtype
)
if
(
k_cache
.
dtype
==
torch
.
uint8
or
v_cache
.
dtype
==
torch
.
uint8
and
kv_cache_dtype
==
"auto"
):
raise
ValueError
(
"kv_cache_dtype='auto' unsupported for
\
FP8 KV Cache prefill kernel"
)
# shape constraints
# shape constraints
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
Lq
,
Lk
,
Lv
=
q
.
shape
[
-
1
],
k
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lq
==
Lk
and
Lk
==
Lv
...
@@ -709,7 +755,6 @@ if triton.__version__ >= "2.1.0":
...
@@ -709,7 +755,6 @@ if triton.__version__ >= "2.1.0":
if
sliding_window
is
None
or
sliding_window
<=
0
:
if
sliding_window
is
None
or
sliding_window
<=
0
:
sliding_window
=
0
sliding_window
=
0
num_warps
=
8
if
Lk
<=
64
else
4
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
q
,
q
,
...
@@ -719,6 +764,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -719,6 +764,8 @@ if triton.__version__ >= "2.1.0":
v_cache
,
v_cache
,
b_loc
,
b_loc
,
sm_scale
,
sm_scale
,
k_scale
,
v_scale
,
b_start_loc
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
b_ctx_len
,
b_ctx_len
,
...
@@ -757,7 +804,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -757,7 +804,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_warps
=
NUM_WARPS
,
num_stages
=
1
,
num_stages
=
1
,
)
)
return
return
...
@@ -770,6 +817,8 @@ if triton.__version__ >= "2.1.0":
...
@@ -770,6 +817,8 @@ if triton.__version__ >= "2.1.0":
v_cache
,
v_cache
,
b_loc
,
b_loc
,
sm_scale
,
sm_scale
,
k_scale
,
v_scale
,
b_start_loc
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
b_ctx_len
,
b_ctx_len
,
...
@@ -807,7 +856,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -807,7 +856,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
,
SLIDING_WINDOW
=
sliding_window
,
num_warps
=
num_warps
,
num_warps
=
NUM_WARPS
,
num_stages
=
1
,
num_stages
=
1
,
)
)
return
return
vllm/attention/selector.py
View file @
af7f4372
import
enum
import
enum
import
os
from
contextlib
import
contextmanager
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Optional
,
Type
from
typing
import
Generator
,
Optional
,
Type
import
torch
import
torch
...
@@ -8,7 +10,7 @@ import vllm.envs as envs
...
@@ -8,7 +10,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_openvino
,
is_tpu
,
is_xpu
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
is_cpu
,
is_hip
,
is_openvino
,
is_xpu
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -24,6 +26,66 @@ class _Backend(enum.Enum):
...
@@ -24,6 +26,66 @@ class _Backend(enum.Enum):
IPEX
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
def backend_name_to_enum(backend_name: str) -> _Backend:
    '''
    Translate an attention backend name into its _Backend member.

    Arguments:

    * backend_name: exact (case-sensitive) name of a _Backend member

    Returns:

    * The _Backend member carrying that name

    Raises:

    * ValueError if no _Backend member has that name
    '''
    assert backend_name is not None

    known_backends = _Backend.__members__
    if backend_name in known_backends:
        return _Backend[backend_name]

    # Unknown name: list every valid spelling so the caller can fix it
    available = ', '.join(known_backends)
    raise ValueError(f"Invalid attention backend '{backend_name}'. "
                     f"Available backends: {available} "
                     "(case-sensitive).")
def get_env_variable_attn_backend() -> Optional[_Backend]:
    '''
    Look up the attention backend override requested through the
    process environment (the variable named by STR_BACKEND_ENV_VAR).

    Returns:

    * None if the environment variable is not set
    * Otherwise the _Backend member named by its value; the name is
      resolved by backend_name_to_enum, which raises ValueError for
      names that are not _Backend members
    '''
    env_value = os.environ.get(STR_BACKEND_ENV_VAR)
    if env_value is None:
        return None
    return backend_name_to_enum(env_value)
# Module-level override for attention backend selection.
# While this is None (the default), the backend is auto-selected
# based on system & workload configuration; any other value forces
# that particular backend. Written by global_force_attn_backend()
# and read by get_global_forced_attn_backend().
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend
:
Optional
[
_Backend
]
=
None
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
    '''
    Record a backend that attention backend selection must use.

    The choice is stored in the module-level `forced_attn_backend`
    variable. Passing `None` clears the override, which re-enables
    automatic backend selection.

    Arguments:

    * attn_backend: backend selection (None to revert to auto)
    '''
    global forced_attn_backend
    forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> Optional[_Backend]:
    '''
    Report the attention backend override currently in force.

    Returns:

    * The value of the module-level `forced_attn_backend` variable:
      a _Backend member when a backend is being forced, or None
      when auto-selection is enabled
    '''
    return forced_attn_backend
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
def
get_attn_backend
(
num_heads
:
int
,
num_heads
:
int
,
...
@@ -101,16 +163,20 @@ def which_attn_to_use(
...
@@ -101,16 +163,20 @@ def which_attn_to_use(
# Default case.
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
FLASH_ATTN
# Check the environment variable and override if specified
# Check whether a particular choice of backend was
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
# previously forced.
if
backend_by_env_var
is
not
None
:
#
backend_members
=
_Backend
.
__members__
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
if
backend_by_env_var
not
in
backend_members
:
# ENVIRONMENT VARIABLE.
raise
ValueError
(
backend_by_global_setting
:
Optional
[
_Backend
]
=
(
f
"Invalid attention backend '
{
backend_by_env_var
}
'. "
get_global_forced_attn_backend
())
f
"Available backends:
{
', '
.
join
(
backend_members
)
}
"
if
backend_by_global_setting
is
not
None
:
"(case-sensitive)."
)
selected_backend
=
backend_by_global_setting
selected_backend
=
_Backend
[
backend_by_env_var
]
else
:
# Check the environment variable and override if specified
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
is_cpu
():
if
is_cpu
():
if
selected_backend
!=
_Backend
.
TORCH_SDPA
:
if
selected_backend
!=
_Backend
.
TORCH_SDPA
:
...
@@ -127,7 +193,7 @@ def which_attn_to_use(
...
@@ -127,7 +193,7 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
return
_Backend
.
IPEX
return
_Backend
.
IPEX
if
is_tpu
():
if
current_platform
.
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
return
_Backend
.
PALLAS
...
@@ -193,3 +259,35 @@ def which_attn_to_use(
...
@@ -193,3 +259,35 @@ def which_attn_to_use(
selected_backend
=
_Backend
.
XFORMERS
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
return
selected_backend
@contextmanager
def global_force_attn_backend_context_manager(
        attn_backend: _Backend) -> Generator[None, None, None]:
    '''
    Temporarily force a vLLM attention backend for the duration of
    a `with` block.

    On entry the global override is replaced by `attn_backend`; on
    exit (normal or via an exception) the override that was active
    on entry is put back.

    Arguments:

    * attn_backend: attention backend to force

    Returns:

    * Generator
    '''
    # Remember whatever override (possibly None) is active on entry
    previous_backend = get_global_forced_attn_backend()

    global_force_attn_backend(attn_backend)
    try:
        # Hand control to the body of the `with` statement
        yield
    finally:
        # Restore the entry-time override even if the body raised
        global_force_attn_backend(previous_backend)
vllm/block.py
View file @
af7f4372
"""Token blocks."""
"""Token blocks."""
from
typing
import
List
from
typing
import
List
,
Optional
from
vllm.utils
import
Device
from
vllm.utils
import
Device
...
@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
...
@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
f
'computed=
{
self
.
computed
}
)'
)
f
'computed=
{
self
.
computed
}
)'
)
# Mapping: logical block number -> physical block.
class BlockTable:
    """Holds a list of blocks with caching of their associated block_ids.

    Two parallel lists are maintained: the blocks themselves and the
    `block_number` of each block at the same index. Every mutating
    method updates both so that `ids()` never has to be recomputed.
    """

    def __init__(self, blocks: Optional[List["PhysicalTokenBlock"]] = None):
        """Create a table, optionally pre-populated from `blocks`.

        The given list is not stored; its elements are appended one by
        one into freshly created internal lists.
        """
        self._blocks: List["PhysicalTokenBlock"] = []
        self._block_ids: List[int] = []

        if blocks is not None:
            for block in blocks:
                self.append(block)

    def append(self, block: "PhysicalTokenBlock"):
        """Add `block` at the end and cache its block number."""
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __len__(self) -> int:
        return len(self._blocks)

    def __getitem__(self, key):
        """Return the block at index `key` (or a list for a slice)."""
        return self._blocks[key]

    def __setitem__(self, key, value):
        """Replace the block(s) at `key`, keeping the id cache in sync.

        For a slice `key`, `value` may be any iterable of blocks; for
        an integer `key` it is a single block.
        """
        if isinstance(key, slice):
            # Materialize once: a generator or other one-shot iterable
            # would be exhausted by the first assignment and leave the
            # cached ids out of sync with the blocks.
            blocks = list(value)
            # Compute the ids before touching either list so a bad
            # element cannot leave the table half-updated.
            block_ids = [b.block_number for b in blocks]
            self._blocks[key] = blocks
            self._block_ids[key] = block_ids
        else:
            block = value
            block_id = block.block_number
            self._blocks[key] = block
            self._block_ids[key] = block_id

    def reset(self):
        """Drop all blocks and their cached ids."""
        self._blocks = []
        self._block_ids = []

    def copy(self) -> "BlockTable":
        """Return a new table holding the same block objects."""
        return BlockTable(self._blocks)

    def list(self) -> List["PhysicalTokenBlock"]:
        """Return the internal list of blocks (not a copy)."""
        return self._blocks

    def ids(self) -> List[int]:
        """Return the internal list of cached block numbers (not a copy)."""
        return self._block_ids
Prev
1
…
10
11
12
13
14
15
16
17
18
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment