Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2daf23ab
Unverified
Commit
2daf23ab
authored
Mar 07, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 07, 2024
Browse files
Separate attention backends (#3005)
parent
cbf4c05b
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
503 additions
and
213 deletions
+503
-213
.gitignore
.gitignore
+3
-0
setup.py
setup.py
+45
-3
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+23
-7
vllm/model_executor/layers/attention/__init__.py
vllm/model_executor/layers/attention/__init__.py
+5
-0
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+59
-0
vllm/model_executor/layers/attention/backends/__init__.py
vllm/model_executor/layers/attention/backends/__init__.py
+0
-0
vllm/model_executor/layers/attention/backends/flash_attn.py
vllm/model_executor/layers/attention/backends/flash_attn.py
+124
-0
vllm/model_executor/layers/attention/backends/xformers.py
vllm/model_executor/layers/attention/backends/xformers.py
+61
-155
vllm/model_executor/layers/attention/ops/__init__.py
vllm/model_executor/layers/attention/ops/__init__.py
+0
-0
vllm/model_executor/layers/attention/ops/paged_attn.py
vllm/model_executor/layers/attention/ops/paged_attn.py
+138
-0
vllm/model_executor/layers/attention/ops/prefix_prefill.py
vllm/model_executor/layers/attention/ops/prefix_prefill.py
+0
-0
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+6
-7
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+5
-5
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+2
-2
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+5
-5
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+14
-14
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+5
-5
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+2
-4
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+5
-5
No files found.
.gitignore
View file @
2daf23ab
...
@@ -184,3 +184,6 @@ _build/
...
@@ -184,3 +184,6 @@ _build/
# Benchmark dataset
# Benchmark dataset
*.json
*.json
# Third-party Python packages.
vllm/thirdparty_files/
setup.py
View file @
2daf23ab
...
@@ -3,6 +3,7 @@ import io
...
@@ -3,6 +3,7 @@ import io
import
os
import
os
import
re
import
re
import
subprocess
import
subprocess
import
sys
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
,
Set
from
typing
import
List
,
Set
...
@@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext
...
@@ -14,6 +15,8 @@ import torch.utils.cpp_extension as torch_cpp_ext
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR
=
"vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# `python setup.py develop` since it will give you incremental builds.
...
@@ -324,8 +327,46 @@ if _is_cuda():
...
@@ -324,8 +327,46 @@ if _is_cuda():
"nvcc"
:
NVCC_FLAGS_PUNICA
,
"nvcc"
:
NVCC_FLAGS_PUNICA
,
},
},
))
))
elif
_is_neuron
():
neuronxcc_version
=
get_neuronxcc_version
()
# Download the FlashAttention package.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version
=
"2.5.6"
install_dir
=
os
.
path
.
join
(
ROOT_DIR
,
THIRDPARTY_SUBDIR
)
subprocess
.
check_call
(
[
sys
.
executable
,
"-m"
,
"pip"
,
"install"
,
"-q"
,
f
"--target=
{
install_dir
}
"
,
"einops"
,
# Dependency of flash-attn.
f
"flash-attn==
{
flash_attn_version
}
"
,
"--no-dependencies"
,
# Required to avoid re-installing torch.
],
env
=
dict
(
os
.
environ
,
CC
=
"gcc"
),
)
# Copy the FlashAttention package into the vLLM package after build.
class
build_ext
(
BuildExtension
):
def
run
(
self
):
super
().
run
()
target_dir
=
os
.
path
.
join
(
self
.
build_lib
,
THIRDPARTY_SUBDIR
)
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
self
.
copy_tree
(
install_dir
,
target_dir
)
class
BinaryDistribution
(
setuptools
.
Distribution
):
def
has_ext_modules
(
self
):
return
True
else
:
build_ext
=
BuildExtension
BinaryDistribution
=
setuptools
.
Distribution
if
_is_neuron
():
neuronxcc_version
=
get_neuronxcc_version
()
vllm_extension_sources
=
[
vllm_extension_sources
=
[
"csrc/cache_kernels.cu"
,
"csrc/cache_kernels.cu"
,
...
@@ -468,6 +509,7 @@ setuptools.setup(
...
@@ -468,6 +509,7 @@ setuptools.setup(
python_requires
=
">=3.8"
,
python_requires
=
">=3.8"
,
install_requires
=
get_requirements
(),
install_requires
=
get_requirements
(),
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
BuildExtension
}
if
not
_is_neuron
()
else
{},
cmdclass
=
{
"build_ext"
:
build_ext
}
if
not
_is_neuron
()
else
{},
distclass
=
BinaryDistribution
,
package_data
=
package_data
,
package_data
=
package_data
,
)
)
tests/kernels/test_prefix_prefill.py
View file @
2daf23ab
...
@@ -3,7 +3,7 @@ import pytest
...
@@ -3,7 +3,7 @@ import pytest
import
time
import
time
import
torch
import
torch
from
vllm.model_executor.layers.
triton_kernel
.prefix_prefill
import
(
from
vllm.model_executor.layers.
attention.ops
.prefix_prefill
import
(
context_attention_fwd
)
context_attention_fwd
)
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
...
...
vllm/__init__.py
View file @
2daf23ab
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
from
vllm.engine.llm_engine
import
LLMEngine
def
_configure_system
():
from
vllm.engine.ray_utils
import
initialize_cluster
import
os
from
vllm.entrypoints.llm
import
LLM
import
sys
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
# Importing flash-attn.
thirdparty_files
=
os
.
path
.
join
(
os
.
path
.
abspath
(
os
.
path
.
dirname
(
__file__
)),
"thirdparty_files"
)
sys
.
path
.
insert
(
0
,
thirdparty_files
)
_configure_system
()
# Delete configuration function.
del
_configure_system
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
# noqa: E402
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# noqa: E402
from
vllm.engine.llm_engine
import
LLMEngine
# noqa: E402
from
vllm.engine.ray_utils
import
initialize_cluster
# noqa: E402
from
vllm.entrypoints.llm
import
LLM
# noqa: E402
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
# noqa: E402
from
vllm.sampling_params
import
SamplingParams
# noqa: E402
__version__
=
"0.3.3"
__version__
=
"0.3.3"
...
...
vllm/model_executor/layers/attention/__init__.py
0 → 100644
View file @
2daf23ab
from
vllm.model_executor.layers.attention.attention
import
Attention
__all__
=
[
"Attention"
,
]
vllm/model_executor/layers/attention/attention.py
0 → 100644
View file @
2daf23ab
"""Attention layer."""
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.utils
import
is_hip
class Attention(nn.Module):
    """Backend-agnostic attention layer.

    The layer accepts query, key, and value tensors that contain either
    prompt tokens or generation tokens, and delegates all work to a backend
    object chosen once at construction time. The backend is expected to:
    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Select and construct the attention backend.

        All arguments are forwarded unchanged, in this order, to the
        constructor of the selected backend.
        """
        super().__init__()
        # The conditions are evaluated left to right with short-circuiting,
        # so the CUDA capability query is skipped when is_hip() is true.
        use_flash_attn = (not is_hip()
                          and torch.cuda.get_device_capability()[0] >= 8
                          and torch.get_default_dtype()
                          in (torch.float16, torch.bfloat16))
        # The backend modules are imported lazily so that only the selected
        # one is loaded.
        if use_flash_attn:
            # Ampere or later NVIDIA GPUs.
            # NOTE(woosuk): FlashAttention does not support FP32.
            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
            backend_cls = FlashAttentionBackend
        else:
            # Turing and Volta NVIDIA GPUs or AMD GPUs.
            # Or FP32 on any GPU.
            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
            backend_cls = XFormersBackend
        self.backend = backend_cls(num_heads, head_size, scale, num_kv_heads,
                                   alibi_slopes, sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the selected backend's forward pass and return its output."""
        backend_args = (query, key, value, key_cache, value_cache,
                        input_metadata)
        return self.backend.forward(*backend_args)
vllm/model_executor/layers/
triton_kernel
/__init__.py
→
vllm/model_executor/layers/
attention/backends
/__init__.py
View file @
2daf23ab
File moved
vllm/model_executor/layers/attention/backends/flash_attn.py
0 → 100644
View file @
2daf23ab
"""Attention layer with Flash and PagedAttention."""
from
typing
import
List
,
Optional
# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from
flash_attn
import
flash_attn_func
import
torch
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention.ops.paged_attn
import
(
PagedAttentionImpl
)
class FlashAttentionBackend:
    """Attention backend that uses FlashAttention for plain prompt runs.

    Prompt runs without a usable KV cache / block table go through
    `flash_attn_func`; prompt runs with cached prefixes and all decoding
    runs are delegated to `PagedAttentionImpl`.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Validate and store the attention configuration.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be one of the sizes reported
                by `PagedAttentionImpl.get_supported_head_sizes()`.
            scale: softmax scale; stored as a Python float.
            num_kv_heads: number of key/value heads; defaults to `num_heads`.
            alibi_slopes: optional slopes, converted to a float32 tensor.
            sliding_window: optional window size; stored as the
                `(window, window)` pair passed to `flash_attn_func`, or
                `(-1, -1)` when not set.

        Raises:
            ValueError: if `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        # flash_attn_func takes the window as a (left, right) pair.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # normal attention
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                # forward_prefix takes exactly these seven arguments; head
                # counts are not part of its signature.
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
vllm/model_executor/layers/attention.py
→
vllm/model_executor/layers/attention
/backends/xformers
.py
View file @
2daf23ab
"""Multi-head attention."""
"""Attention layer with xFormers and PagedAttention."""
import
importlib
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
importlib
import
torch
import
torch
import
torch.nn
as
nn
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
BlockDiagonalCausalMask
,
from
xformers.ops.fmha.attn_bias
import
(
BlockDiagonalCausalMask
,
LowerTriangularMaskWithTensorBias
)
LowerTriangularMaskWithTensorBias
)
from
vllm._C
import
ops
from
vllm._C
import
cache_ops
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.
triton_kernel.prefix_prefill
import
(
from
vllm.model_executor.layers.
attention.ops.paged_attn
import
(
context_a
ttention
_fwd
)
PagedA
ttention
Impl
)
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
_SUPPORTED_HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
256
]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
class
PagedAttention
(
nn
.
Module
):
"""MHA/MQA/GQA layer with PagedAttention.
This class takes query, key, and value tensors as input. The input tensors
class
XFormersBackend
:
can either contain prompt tokens or generation tokens.
The class does the following:
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
3. Return the output tensor.
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -42,7 +24,6 @@ class PagedAttention(nn.Module):
...
@@ -42,7 +24,6 @@ class PagedAttention(nn.Module):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
@@ -50,48 +31,17 @@ class PagedAttention(nn.Module):
...
@@ -50,48 +31,17 @@ class PagedAttention(nn.Module):
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
register_buffer
(
"
alibi_slopes
"
,
alibi_slopes
,
persistent
=
False
)
self
.
alibi_slopes
=
alibi_slopes
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
suppored_head_sizes
=
PagedAttentionImpl
.
get_supported_head_sizes
()
if
head_size
not
in
suppored_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
if
self
.
head_size
not
in
_SUPPORTED_HEAD_SIZES
:
self
.
use_ref_attention
=
_check_use_ref_attention
()
raise
ValueError
(
f
"head_size (
{
self
.
head_size
}
) is not supported. "
f
"Supported head sizes:
{
_SUPPORTED_HEAD_SIZES
}
."
)
self
.
use_ref_attention
=
self
.
check_use_ref_attention
()
def
check_use_ref_attention
(
self
)
->
bool
:
if
not
is_hip
():
return
False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return
importlib
.
util
.
find_spec
(
"flash_attn"
)
is
None
def
ref_masked_attention
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
seq_len
,
_
,
_
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
self
.
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
def
forward
(
def
forward
(
self
,
self
,
...
@@ -102,7 +52,7 @@ class PagedAttention(nn.Module):
...
@@ -102,7 +52,7 @@ class PagedAttention(nn.Module):
value_cache
:
Optional
[
torch
.
Tensor
],
value_cache
:
Optional
[
torch
.
Tensor
],
input_metadata
:
InputMetadata
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""PagedAttention
forward pass
.
"""
Forward pass with xFormers and
PagedAttention.
Args:
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
query: shape = [batch_size, seq_len, num_heads * head_size]
...
@@ -127,19 +77,14 @@ class PagedAttention(nn.Module):
...
@@ -127,19 +77,14 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory
# vectors will not be cached. This happens during the initial memory
# profiling run.
# profiling run.
if
key_cache
is
not
None
and
value_cache
is
not
None
:
if
key_cache
is
not
None
and
value_cache
is
not
None
:
cache_ops
.
reshape_and_cache
(
PagedAttentionImpl
.
reshape_and_cache
(
key
,
value
,
key_cache
,
key
,
value_cache
,
input_metadata
)
value
,
key_cache
,
value_cache
,
input_metadata
.
slot_mapping
.
flatten
(),
input_metadata
.
kv_cache_dtype
,
)
if
input_metadata
.
is_prompt
:
if
input_metadata
.
is_prompt
:
#
normal attention
#
Prompt run.
if
(
key_cache
is
None
or
value_cache
is
None
if
(
key_cache
is
None
or
value_cache
is
None
or
input_metadata
.
block_tables
.
numel
()
==
0
):
or
input_metadata
.
block_tables
.
numel
()
==
0
):
# normal attention
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# project the key and value tensors to the desired number of
...
@@ -175,13 +120,19 @@ class PagedAttention(nn.Module):
...
@@ -175,13 +120,19 @@ class PagedAttention(nn.Module):
seq_len
,
query
.
dtype
)
seq_len
,
query
.
dtype
)
if
self
.
use_ref_attention
:
if
self
.
use_ref_attention
:
output
=
self
.
ref_masked_attention
(
output
=
_
ref_masked_attention
(
query
,
query
,
key
,
key
,
value
,
value
,
self
.
num_heads
,
self
.
num_kv_heads
,
self
.
head_size
,
self
.
scale
,
)
)
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride
# Using view got RuntimeError: view size is not compatible
# (at least one dimension spans across two contiguous subspaces). Use reshape instead
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
# TODO(woosuk): Too many view operations. Let's try to reduce
# TODO(woosuk): Too many view operations. Let's try to reduce
...
@@ -206,27 +157,21 @@ class PagedAttention(nn.Module):
...
@@ -206,27 +157,21 @@ class PagedAttention(nn.Module):
(
is_hip
())
else
None
,
(
is_hip
())
else
None
,
)
)
output
=
out
.
view_as
(
query
)
output
=
out
.
view_as
(
query
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
output
=
torch
.
empty_like
(
query
)
output
=
PagedAttentionImpl
.
forward_prefix
(
context_attention_fwd
(
query
,
query
,
key
,
key
,
value
,
value
,
output
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
input_metadata
.
block_tables
,
# [BS, max_block_per_request]
input_metadata
,
input_metadata
.
start_loc
,
self
.
alibi_slopes
,
input_metadata
.
prompt_lens
,
input_metadata
.
context_lens
,
input_metadata
.
max_seq_len
,
getattr
(
self
,
"alibi_slopes"
,
None
),
)
)
else
:
else
:
# Decoding run.
# Decoding run.
output
=
_p
aged
_a
ttention
(
output
=
P
aged
A
ttention
Impl
.
forward_decode
(
query
,
query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
@@ -274,76 +219,37 @@ def _make_alibi_bias(
...
@@ -274,76 +219,37 @@ def _make_alibi_bias(
return
attn_bias
return
attn_bias
def
_paged_attention
(
def _check_use_ref_attention() -> bool:
    """Return True when the reference attention implementation must be used.

    This is only ever True on ROCm (`is_hip()`), and only when the
    `flash_attn` package cannot be found.
    """
    if not is_hip():
        return False
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly before using it.
    import importlib.util

    # For ROCm, check whether flash attention is installed or not.
    # if not, use_ref_attention needs to be True
    return importlib.util.find_spec("flash_attn") is None
def
_ref_masked_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
_cache
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
_cache
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
block_size
=
value_cache
.
shape
[
3
]
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
(
seq_len
,
_
,
_
=
query
.
shape
(
input_metadata
.
max_context_len
+
_PARTITION_SIZE
-
1
)
//
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
_PARTITION_SIZE
)
seq_len
,
# NOTE(woosuk): We use a simple heuristic to decide whether to use
dtype
=
query
.
dtype
,
# PagedAttention V1 or V2. If the number of partitions is 1, we use
device
=
query
.
device
),
# V1 to avoid the overhead of reduction. Also, if the number of
diagonal
=
1
)
# sequences or heads is large, we use V1 since there is enough work
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
# to parallelize.
# TODO(woosuk): Tune this heuristic.
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
attn_weights
=
attn_weights
+
attn_mask
.
float
()
use_v1
=
input_metadata
.
max_context_len
<=
8192
and
(
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
if
use_v1
:
return
out
# Run PagedAttention V1.
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
input_metadata
.
block_tables
,
input_metadata
.
context_lens
,
block_size
,
input_metadata
.
max_context_len
,
alibi_slopes
,
input_metadata
.
kv_cache_dtype
,
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
input_metadata
.
block_tables
,
input_metadata
.
context_lens
,
block_size
,
input_metadata
.
max_context_len
,
alibi_slopes
,
input_metadata
.
kv_cache_dtype
,
)
return
output
vllm/model_executor/layers/attention/ops/__init__.py
0 → 100644
View file @
2daf23ab
vllm/model_executor/layers/attention/ops/paged_attn.py
0 → 100644
View file @
2daf23ab
from
typing
import
List
,
Optional
import
torch
from
vllm._C
import
cache_ops
from
vllm._C
import
ops
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.attention.ops.prefix_prefill
import
(
context_attention_fwd
)
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
class PagedAttentionImpl:
    """Static helpers wrapping the cache and paged-attention custom ops."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by the attention backends."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """Write `key` and `value` into the caches via `cache_ops`.

        The slot mapping from `input_metadata` is flattened before it is
        handed to the op, together with the KV cache dtype.
        """
        flat_slot_mapping = input_metadata.slot_mapping.flatten()
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            flat_slot_mapping,
            input_metadata.kv_cache_dtype,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the PagedAttention V1 or V2 op for a decoding step.

        `query` is unpacked as (num_seqs, num_heads, head_size), and the
        block size is read from dimension 3 of `value_cache`. Returns a new
        tensor shaped like `query`, filled in by the selected op.
        """
        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        # Ceiling division of the maximum context length by the partition
        # size.
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory
        # shortage.
        single_partition_or_busy = (max_num_partitions == 1
                                    or num_seqs * num_heads > 512)
        use_v1 = (input_metadata.max_context_len <= 8192
                  and single_partition_or_busy)

        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
            return output

        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        # Scratch buffers with one slot per partition, passed to the V2 op.
        partition_shape = (num_seqs, num_heads, max_num_partitions)
        tmp_output = torch.empty(
            size=partition_shape + (head_size, ),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=partition_shape,
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run prefix-enabled prompt attention via `context_attention_fwd`.

        Returns a new tensor shaped like `query` that the kernel writes its
        result into.
        """
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            input_metadata.block_tables,  # [BS, max_block_per_request]
            input_metadata.start_loc,
            input_metadata.prompt_lens,
            input_metadata.context_lens,
            input_metadata.max_seq_len,
            alibi_slopes,
        )
        return output
vllm/model_executor/layers/
triton_kernel
/prefix_prefill.py
→
vllm/model_executor/layers/
attention/ops
/prefix_prefill.py
View file @
2daf23ab
File moved
vllm/model_executor/models/baichuan.py
View file @
2daf23ab
...
@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
...
@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -151,10 +151,10 @@ class BaiChuanAttention(nn.Module):
...
@@ -151,10 +151,10 @@ class BaiChuanAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scaling
,
scaling
,
alibi_slopes
=
alibi_slopes
)
alibi_slopes
=
alibi_slopes
)
else
:
else
:
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -163,8 +163,7 @@ class BaiChuanAttention(nn.Module):
...
@@ -163,8 +163,7 @@ class BaiChuanAttention(nn.Module):
base
=
self
.
rope_theta
,
base
=
self
.
rope_theta
,
)
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
scaling
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/bloom.py
View file @
2daf23ab
...
@@ -25,7 +25,7 @@ from transformers import BloomConfig
...
@@ -25,7 +25,7 @@ from transformers import BloomConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -107,10 +107,10 @@ class BloomAttention(nn.Module):
...
@@ -107,10 +107,10 @@ class BloomAttention(nn.Module):
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
scaling
=
self
.
head_dim
**-
0.5
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scaling
,
scaling
,
alibi_slopes
=
alibi_slopes
)
alibi_slopes
=
alibi_slopes
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/chatglm.py
View file @
2daf23ab
...
@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
...
@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
...
@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
base
=
10000
*
rope_ratio
,
base
=
10000
*
rope_ratio
,
is_neox_style
=
False
,
is_neox_style
=
False
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
...
...
vllm/model_executor/models/deepseek.py
View file @
2daf23ab
...
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
...
@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
...
@@ -229,10 +229,10 @@ class DeepseekAttention(nn.Module):
...
@@ -229,10 +229,10 @@ class DeepseekAttention(nn.Module):
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/falcon.py
View file @
2daf23ab
...
@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
...
@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -150,10 +150,10 @@ class FalconAttention(nn.Module):
...
@@ -150,10 +150,10 @@ class FalconAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
inv_norm_factor
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
elif
self
.
use_alibi
:
elif
self
.
use_alibi
:
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
head_start
=
tp_rank
*
self
.
num_heads
head_start
=
tp_rank
*
self
.
num_heads
...
@@ -161,16 +161,16 @@ class FalconAttention(nn.Module):
...
@@ -161,16 +161,16 @@ class FalconAttention(nn.Module):
alibi_slopes
=
(
_get_alibi_slopes
(
self
.
total_num_heads
)
*
alibi_slopes
=
(
_get_alibi_slopes
(
self
.
total_num_heads
)
*
self
.
inv_norm_factor
)
self
.
inv_norm_factor
)
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
alibi_slopes
=
alibi_slopes
[
head_start
:
head_end
].
tolist
()
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
inv_norm_factor
,
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
alibi_slopes
=
alibi_slopes
)
alibi_slopes
=
alibi_slopes
)
else
:
else
:
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scale
=
self
.
inv_norm_factor
,
scale
=
self
.
inv_norm_factor
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/gemma.py
View file @
2daf23ab
...
@@ -23,7 +23,7 @@ from transformers import GemmaConfig
...
@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -123,10 +123,10 @@ class GemmaAttention(nn.Module):
...
@@ -123,10 +123,10 @@ class GemmaAttention(nn.Module):
base
=
self
.
rope_theta
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
is_neox_style
=
True
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/gpt2.py
View file @
2daf23ab
...
@@ -25,7 +25,7 @@ from transformers import GPT2Config
...
@@ -25,7 +25,7 @@ from transformers import GPT2Config
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -73,9 +73,7 @@ class GPT2Attention(nn.Module):
...
@@ -73,9 +73,7 @@ class GPT2Attention(nn.Module):
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
scale
=
self
.
scale
)
self
.
head_dim
,
scale
=
self
.
scale
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
2daf23ab
...
@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig
...
@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
Paged
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -85,10 +85,10 @@ class GPTBigCodeAttention(nn.Module):
...
@@ -85,10 +85,10 @@ class GPTBigCodeAttention(nn.Module):
bias
=
True
,
bias
=
True
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scale
=
self
.
scale
,
scale
=
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment