Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2daf23ab
Unverified
Commit
2daf23ab
authored
Mar 07, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 07, 2024
Browse files
Separate attention backends (#3005)
parent
cbf4c05b
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
503 additions
and
213 deletions
+503
-213
.gitignore
.gitignore
+3
-0
setup.py
setup.py
+45
-3
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+23
-7
vllm/model_executor/layers/attention/__init__.py
vllm/model_executor/layers/attention/__init__.py
+5
-0
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+59
-0
vllm/model_executor/layers/attention/backends/__init__.py
vllm/model_executor/layers/attention/backends/__init__.py
+0
-0
vllm/model_executor/layers/attention/backends/flash_attn.py
vllm/model_executor/layers/attention/backends/flash_attn.py
+124
-0
vllm/model_executor/layers/attention/backends/xformers.py
vllm/model_executor/layers/attention/backends/xformers.py
+61
-155
vllm/model_executor/layers/attention/ops/__init__.py
vllm/model_executor/layers/attention/ops/__init__.py
+0
-0
vllm/model_executor/layers/attention/ops/paged_attn.py
vllm/model_executor/layers/attention/ops/paged_attn.py
+138
-0
vllm/model_executor/layers/attention/ops/prefix_prefill.py
vllm/model_executor/layers/attention/ops/prefix_prefill.py
+0
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+6
-7
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+5
-5
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+2
-2
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+5
-5
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+14
-14
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+5
-5
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+2
-4
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+5
-5
No files found.
.gitignore
View file @
2daf23ab
...
...
@@ -184,3 +184,6 @@ _build/
# Benchmark dataset
*.json
# Third-party Python packages.
vllm/thirdparty_files/
setup.py
View file @
2daf23ab
...
...
@@ -3,6 +3,7 @@ import io
import
os
import
re
import
subprocess
import
sys
import
warnings
from
pathlib
import
Path
from
typing
import
List
,
Set
...
...
@@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR
=
"vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
...
...
@@ -324,8 +327,46 @@ if _is_cuda():
"nvcc"
:
NVCC_FLAGS_PUNICA
,
},
))
elif
_is_neuron
():
neuronxcc_version
=
get_neuronxcc_version
()
# Download the FlashAttention package.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version
=
"2.5.6"
install_dir
=
os
.
path
.
join
(
ROOT_DIR
,
THIRDPARTY_SUBDIR
)
subprocess
.
check_call
(
[
sys
.
executable
,
"-m"
,
"pip"
,
"install"
,
"-q"
,
f
"--target=
{
install_dir
}
"
,
"einops"
,
# Dependency of flash-attn.
f
"flash-attn==
{
flash_attn_version
}
"
,
"--no-dependencies"
,
# Required to avoid re-installing torch.
],
env
=
dict
(
os
.
environ
,
CC
=
"gcc"
),
)
# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):
    """Build the extensions, then bundle the downloaded third-party files.

    After the regular ``BuildExtension`` build has finished, everything under
    ``install_dir`` (the ``pip install --target`` directory) is copied into
    ``<build_lib>/<THIRDPARTY_SUBDIR>``.
    """

    def run(self):
        """Run the normal extension build, then copy ``install_dir`` over."""
        super().run()
        target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
        # `exist_ok=True` replaces the exists()-then-makedirs() check, which
        # could raise FileExistsError if the directory appeared in between.
        os.makedirs(target_dir, exist_ok=True)
        self.copy_tree(install_dir, target_dir)
class BinaryDistribution(setuptools.Distribution):
    """Distribution that always reports having extension modules."""

    def has_ext_modules(self):
        """Unconditionally answer that extension modules are present.

        NOTE(review): presumably this marks the wheel as platform-specific
        because of the bundled binaries -- confirm.
        """
        return True
else
:
build_ext
=
BuildExtension
BinaryDistribution
=
setuptools
.
Distribution
if
_is_neuron
():
neuronxcc_version
=
get_neuronxcc_version
()
vllm_extension_sources
=
[
"csrc/cache_kernels.cu"
,
...
...
@@ -468,6 +509,7 @@ setuptools.setup(
python_requires
=
">=3.8"
,
install_requires
=
get_requirements
(),
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
BuildExtension
}
if
not
_is_neuron
()
else
{},
cmdclass
=
{
"build_ext"
:
build_ext
}
if
not
_is_neuron
()
else
{},
distclass
=
BinaryDistribution
,
package_data
=
package_data
,
)
tests/kernels/test_prefix_prefill.py
View file @
2daf23ab
...
...
@@ -3,7 +3,7 @@ import pytest
import
time
import
torch
from
vllm.model_executor.layers.
triton_kernel
.prefix_prefill
import
(
from
vllm.model_executor.layers.
attention.ops
.prefix_prefill
import
(
context_attention_fwd
)
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
...
...
vllm/__init__.py
View file @
2daf23ab
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.ray_utils
import
initialize_cluster
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
def _configure_system():
    """Put the bundled ``thirdparty_files`` directory first on ``sys.path``.

    The directory is resolved relative to this module's own location and is
    inserted at index 0 of ``sys.path``.
    """
    import os
    import sys

    # Importing flash-attn.
    package_dir = os.path.abspath(os.path.dirname(__file__))
    bundled_dir = os.path.join(package_dir, "thirdparty_files")
    sys.path.insert(0, bundled_dir)
_configure_system
()
# Delete configuration function.
del
_configure_system
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
# noqa: E402
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# noqa: E402
from
vllm.engine.llm_engine
import
LLMEngine
# noqa: E402
from
vllm.engine.ray_utils
import
initialize_cluster
# noqa: E402
from
vllm.entrypoints.llm
import
LLM
# noqa: E402
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
# noqa: E402
from
vllm.sampling_params
import
SamplingParams
# noqa: E402
__version__
=
"0.3.3"
...
...
vllm/model_executor/layers/attention/__init__.py
0 → 100644
View file @
2daf23ab
from
vllm.model_executor.layers.attention.attention
import
Attention
__all__
=
[
"Attention"
,
]
vllm/model_executor/layers/attention/attention.py
0 → 100644
View file @
2daf23ab
"""Attention layer."""
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.utils
import
is_hip
class Attention(nn.Module):
    """Attention layer that delegates to a backend chosen at construction.

    The layer accepts query, key, and value tensors that hold either prompt
    tokens or generation tokens. The selected backend then:
    1. Stores the input key and value tensors in the KV cache.
    2. Performs (multi-head/multi-query/grouped-query) attention.
    3. Returns the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Pick the attention backend and build it with the given settings.

        The FlashAttention backend is used when not running on HIP, the CUDA
        device capability major version is at least 8, and the default dtype
        is float16 or bfloat16. Otherwise the xFormers backend is used.
        """
        super().__init__()
        use_flash_attn = (
            not is_hip()
            and torch.cuda.get_device_capability()[0] >= 8
            and torch.get_default_dtype() in (torch.float16, torch.bfloat16))

        # The backend modules are imported lazily so that only the selected
        # backend's dependencies have to be importable.
        if use_flash_attn:
            # Ampere or later NVIDIA GPUs.
            # NOTE(woosuk): FlashAttention does not support FP32.
            from vllm.model_executor.layers.attention.backends.flash_attn import (  # noqa: E501
                FlashAttentionBackend as backend_cls)
        else:
            # Turing and Volta NVIDIA GPUs or AMD GPUs.
            # Or FP32 on any GPU.
            from vllm.model_executor.layers.attention.backends.xformers import (  # noqa: E501
                XFormersBackend as backend_cls)

        self.backend = backend_cls(num_heads, head_size, scale, num_kv_heads,
                                   alibi_slopes, sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward all arguments unchanged to the selected backend."""
        backend_forward = self.backend.forward
        return backend_forward(query, key, value, key_cache, value_cache,
                               input_metadata)
vllm/model_executor/layers/
triton_kernel
/__init__.py
→
vllm/model_executor/layers/
attention/backends
/__init__.py
View file @
2daf23ab
File moved
vllm/model_executor/layers/attention/backends/flash_attn.py
0 → 100644
View file @
2daf23ab
"""Attention layer with Flash and PagedAttention."""
from
typing
import
List
,
Optional
# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from
flash_attn
import
flash_attn_func
import
torch
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention.ops.paged_attn
import
(
PagedAttentionImpl
)
class FlashAttentionBackend:
    """Attention backend combining FlashAttention and PagedAttention.

    Prompt runs without a usable KV cache / block table go through
    ``flash_attn_func``; prompt runs with a block table go through
    ``PagedAttentionImpl.forward_prefix``; decoding runs go through
    ``PagedAttentionImpl.forward_decode``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention configuration.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be one of
                ``PagedAttentionImpl.get_supported_head_sizes()``.
            scale: softmax scale, stored as a Python float.
            num_kv_heads: number of key/value heads; defaults to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: optional per-head ALiBi slopes; converted to a
                float32 tensor when given.
            sliding_window: optional window size; stored as the
                ``(left, right)`` tuple passed to ``flash_attn_func``, or
                ``(-1, -1)`` when not set.

        Raises:
            ValueError: if ``head_size`` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # flash_attn_func takes the window as a (left, right) pair.
        self.sliding_window = ((self.sliding_window, self.sliding_window)
                               if self.sliding_window is not None else
                               (-1, -1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # normal attention
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                # forward_prefix accepts exactly these seven arguments;
                # passing the head counts as well raised a TypeError.
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
vllm/model_executor/layers/attention.py
→
vllm/model_executor/layers/attention
/backends/xformers
.py
View file @
2daf23ab
"""Multi-head attention."""
"""Attention layer with xFormers and PagedAttention."""
import
importlib
from
typing
import
List
,
Optional
import
importlib
import
torch
import
torch.nn
as
nn
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
BlockDiagonalCausalMask
,
LowerTriangularMaskWithTensorBias
)
from
vllm._C
import
ops
from
vllm._C
import
cache_ops
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.
triton_kernel.prefix_prefill
import
(
context_a
ttention
_fwd
)
from
vllm.model_executor.layers.
attention.ops.paged_attn
import
(
PagedA
ttention
Impl
)
from
vllm.utils
import
is_hip
_SUPPORTED_HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
256
]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
class
PagedAttention
(
nn
.
Module
):
"""MHA/MQA/GQA layer with PagedAttention.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
3. Return the output tensor.
"""
class
XFormersBackend
:
def
__init__
(
self
,
...
...
@@ -42,7 +24,6 @@ class PagedAttention(nn.Module):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -50,48 +31,17 @@ class PagedAttention(nn.Module):
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
register_buffer
(
"
alibi_slopes
"
,
alibi_slopes
,
persistent
=
False
)
self
.
alibi_slopes
=
alibi_slopes
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttentionImpl
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
if
self
.
head_size
not
in
_SUPPORTED_HEAD_SIZES
:
raise
ValueError
(
f
"head_size (
{
self
.
head_size
}
) is not supported. "
f
"Supported head sizes:
{
_SUPPORTED_HEAD_SIZES
}
."
)
self
.
use_ref_attention
=
self
.
check_use_ref_attention
()
def
check_use_ref_attention
(
self
)
->
bool
:
if
not
is_hip
():
return
False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
None
def
ref_masked_attention
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
seq_len
,
_
,
_
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
self
.
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
self
.
use_ref_attention
=
_check_use_ref_attention
()
def
forward
(
self
,
...
...
@@ -102,7 +52,7 @@ class PagedAttention(nn.Module):
value_cache
:
Optional
[
torch
.
Tensor
],
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
"""PagedAttention
forward pass
.
"""
Forward pass with xFormers and
PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
...
...
@@ -127,19 +77,14 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory
# profiling run.
if
key_cache
is
not
None
and
value_cache
is
not
None
:
cache_ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
input_metadata
.
slot_mapping
.
flatten
(),
input_metadata
.
kv_cache_dtype
,
)
PagedAttentionImpl
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
input_metadata
)
if
input_metadata
.
is_prompt
:
#
normal attention
#
Prompt run.
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
...
...
@@ -175,13 +120,19 @@ class PagedAttention(nn.Module):
seq_len
,
query
.
dtype
)
if
self
.
use_ref_attention
:
output
=
self
.
ref_masked_attention
(
output
=
_
ref_masked_attention
(
query
,
key
,
value
,
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use reshape instead
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
# TODO(woosuk): Too many view operations. Let's try to reduce
...
...
@@ -206,27 +157,21 @@ class PagedAttention(nn.Module):
(
is_hip
())
else
None
,
)
output
=
out
.
view_as
(
query
)
else
:
# prefix-enabled attention
output
=
torch
.
empty_like
(
query
)
context_attention_fwd
(
output
=
PagedAttentionImpl
.
forward_prefix
(
query
,
key
,
value
,
output
,
key_cache
,
value_cache
,
input_metadata
.
block_tables
,
# [BS, max_block_per_request]
input_metadata
.
start_loc
,
input_metadata
.
prompt_lens
,
input_metadata
.
context_lens
,
input_metadata
.
max_seq_len
,
getattr
(
self
,
"alibi_slopes"
,
None
),
input_metadata
,
self
.
alibi_slopes
,
)
else
:
# Decoding run.
output
=
_p
aged
_a
ttention
(
output
=
P
aged
A
ttention
Impl
.
forward_decode
(
query
,
key_cache
,
value_cache
,
...
...
@@ -274,76 +219,37 @@ def _make_alibi_bias(
return
attn_bias
def
_paged_attention
(
def _check_use_ref_attention() -> bool:
    """Return whether the reference attention implementation must be used.

    Returns:
        False when not running on HIP. On HIP, True exactly when the
        ``flash_attn`` package cannot be found.
    """
    if not is_hip():
        return False
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before calling find_spec.
    import importlib.util

    # For ROCm, check whether flash attention is installed or not.
    # If not, use_ref_attention needs to be True.
    return importlib.util.find_spec("flash_attn") is None
def
_ref_masked_attention
(
query
:
torch
.
Tensor
,
key
_cache
:
torch
.
Tensor
,
value
_cache
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
num_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
(
(
input_metadata
.
max_context_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
input_metadata
.
max_context_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
)
if
use_v1
:
# Run PagedAttention V1.
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
input_metadata
.
block_tables
,
input_metadata
.
context_lens
,
block_size
,
input_metadata
.
max_context_len
,
alibi_slopes
,
input_metadata
.
kv_cache_dtype
,
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
input_metadata
.
block_tables
,
input_metadata
.
context_lens
,
block_size
,
input_metadata
.
max_context_len
,
alibi_slopes
,
input_metadata
.
kv_cache_dtype
,
)
return
output
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
seq_len
,
_
,
_
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
vllm/model_executor/layers/attention/ops/__init__.py
0 → 100644
View file @
2daf23ab
vllm/model_executor/layers/attention/ops/paged_attn.py
0 → 100644
View file @
2daf23ab
from
typing
import
List
,
Optional
import
torch
from
vllm._C
import
cache_ops
from
vllm._C
import
ops
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention.ops.prefix_prefill
import
(
context_attention_fwd
)
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
class PagedAttentionImpl:
    """Static helpers wrapping the PagedAttention cache and attention ops."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by the attention backends."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """Write ``key``/``value`` into the caches via ``cache_ops``.

        The slot mapping from ``input_metadata`` is flattened before it is
        handed to the op, together with the KV cache dtype.
        """
        flat_slot_mapping = input_metadata.slot_mapping.flatten()
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            flat_slot_mapping,
            input_metadata.kv_cache_dtype,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the PagedAttention V1 or V2 op and return the output tensor.

        ``query`` is unpacked as (num_seqs, num_heads, head_size) and the
        block size is read from dimension 3 of ``value_cache``. The result
        has the same shape and dtype as ``query``.
        """
        result = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        # Ceiling division of the maximum context length by the partition
        # size.
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory
        # shortage.
        use_v1 = input_metadata.max_context_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512)

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                result,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
            return result

        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        partial_shape = (num_seqs, num_heads, max_num_partitions)
        tmp_output = torch.empty(
            size=partial_shape + (head_size, ),
            dtype=result.dtype,
            device=result.device,
        )
        exp_sums = torch.empty(
            size=partial_shape,
            dtype=torch.float32,
            device=result.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            result,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
        return result

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``context_attention_fwd`` and return the filled output.

        The output buffer is allocated with the same shape and dtype as
        ``query`` and is written in place by the kernel.
        """
        result = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            result,
            key_cache,
            value_cache,
            input_metadata.block_tables,  # [BS, max_block_per_request]
            input_metadata.start_loc,
            input_metadata.prompt_lens,
            input_metadata.context_lens,
            input_metadata.max_seq_len,
            alibi_slopes,
        )
        return result
vllm/model_executor/layers/
triton_kernel
/prefix_prefill.py
→
vllm/model_executor/layers/
attention/ops
/prefix_prefill.py
View file @
2daf23ab
File moved
vllm/model_executor/models/baichuan.py
View file @
2daf23ab
...
...
@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -151,10 +151,10 @@ class BaiChuanAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
else
:
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
...
...
@@ -163,8 +163,7 @@ class BaiChuanAttention(nn.Module):
base
=
self
.
rope_theta
,
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
def
forward
(
self
,
...
...
vllm/model_executor/models/bloom.py
View file @
2daf23ab
...
...
@@ -25,7 +25,7 @@ from transformers import BloomConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -107,10 +107,10 @@ class BloomAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scaling
,
alibi_slopes
=
alibi_slopes
)
def
forward
(
self
,
...
...
vllm/model_executor/models/chatglm.py
View file @
2daf23ab
...
...
@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
base
=
10000
*
rope_ratio
,
is_neox_style
=
False
,
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
vllm/model_executor/models/deepseek.py
View file @
2daf23ab
...
...
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
...
...
@@ -229,10 +229,10 @@ class DeepseekAttention(nn.Module):
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
...
...
vllm/model_executor/models/falcon.py
View file @
2daf23ab
...
...
@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -150,10 +150,10 @@ class FalconAttention(nn.Module):
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
elif
self
.
use_alibi
:
tp_rank
=
get_tensor_model_parallel_rank
()
head_start
=
tp_rank
*
self
.
num_heads
...
...
@@ -161,16 +161,16 @@ class FalconAttention(nn.Module):
alibi_slopes
=
(
_get_alibi_slopes
(
self
.
total_num_heads
)
*
self
.
inv_norm_factor
)
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
,
alibi_slopes
=
alibi_slopes
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
,
alibi_slopes
=
alibi_slopes
)
else
:
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
...
...
vllm/model_executor/models/gemma.py
View file @
2daf23ab
...
...
@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
...
...
@@ -123,10 +123,10 @@ class GemmaAttention(nn.Module):
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
...
...
vllm/model_executor/models/gpt2.py
View file @
2daf23ab
...
...
@@ -25,7 +25,7 @@ from transformers import GPT2Config
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -73,9 +73,7 @@ class GPT2Attention(nn.Module):
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
def
forward
(
self
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
2daf23ab
...
...
@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
QKVParallelLinear
,
...
...
@@ -85,10 +85,10 @@ class GPTBigCodeAttention(nn.Module):
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment