Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04668ebe
Unverified
Commit
04668ebe
authored
Nov 24, 2024
by
Isotr0py
Committed by
GitHub
Nov 23, 2024
Browse files
[Bugfix] Avoid import AttentionMetadata explicitly in Mllama (#10593)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
651f6c31
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
21 additions
and
11 deletions
+21
-11
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+5
-0
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+7
-7
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+6
-2
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-1
No files found.
vllm/attention/backends/blocksparse_attn.py
View file @
04668ebe
...
@@ -87,6 +87,11 @@ class BlocksparseParams:
...
@@ -87,6 +87,11 @@ class BlocksparseParams:
class
BlocksparseFlashAttentionBackend
(
AttentionBackend
):
class
BlocksparseFlashAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
# For attention layer compatibility
return
"FLASH_ATTN"
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"BlocksparseFlashAttentionImpl"
]:
def
get_impl_cls
()
->
Type
[
"BlocksparseFlashAttentionImpl"
]:
return
BlocksparseFlashAttentionImpl
return
BlocksparseFlashAttentionImpl
...
...
vllm/attention/layer.py
View file @
04668ebe
...
@@ -6,7 +6,7 @@ import torch.nn as nn
...
@@ -6,7 +6,7 @@ import torch.nn as nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -98,6 +98,7 @@ class Attention(nn.Module):
...
@@ -98,6 +98,7 @@ class Attention(nn.Module):
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
)
blocksparse_params
,
logits_soft_cap
)
self
.
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# torch.compile works by registering the attention as one giant
...
...
vllm/model_executor/models/mllama.py
View file @
04668ebe
...
@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
...
@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
import
vllm.distributed.parallel_state
as
ps
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.xformers
import
XFormersMetadata
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DummyData
,
EncoderDecoderInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DummyData
,
EncoderDecoderInputs
,
...
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Skip writing kv-cache for the initial profiling run.
# Skip writing kv-cache for the initial profiling run.
if
len
(
kv_cache
.
shape
)
>
1
:
if
len
(
kv_cache
.
shape
)
>
1
:
if
isinstance
(
attn_metadata
,
FlashAttentionMetadata
):
if
self
.
attn
.
backend
in
(
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
):
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_v
=
torch
.
cat
([
v
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
...
@@ -842,7 +842,7 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -842,7 +842,7 @@ class MllamaTextCrossAttention(nn.Module):
1.0
,
1.0
,
1.0
,
1.0
,
)
)
elif
isinstance
(
attn_metadata
,
XFormersMetadata
):
elif
self
.
attn
.
backend
in
(
_Backend
.
XFORMERS
,
_Backend
.
TORCH_SDPA
):
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
kv_cache
,
self
.
num_local_key_value_heads
,
self
.
head_dim
)
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
cached_k
=
torch
.
cat
([
k
[
s
:
e
]
for
s
,
e
in
kv_range_for_decode
])
...
@@ -852,9 +852,9 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -852,9 +852,9 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
attn_metadata
.
cross_slot_mapping
,
"auto"
,
1.0
,
1.0
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported Attention
Metadata
{
type
(
attn_metadata
)
}
"
f
"Unsupported Attention
backend
{
self
.
attn
.
backend
}
"
f
"class
found. Expected the Attention
Metadata to
"
"enum
found. Expected the Attention
backend to be
"
f
"be either XFormersMetadata or FlashAttentionMetadata
."
)
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA
."
)
# We have to call torch.sdpa for prefill when using a
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# custom cross-attention mask. Because the mask is not a
...
...
vllm/platforms/openvino.py
View file @
04668ebe
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
openvino
as
ov
import
openvino.properties.hint
as
hints
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -16,6 +14,12 @@ else:
...
@@ -16,6 +14,12 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
import
openvino
as
ov
import
openvino.properties.hint
as
hints
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import OpenVINO with %r"
,
e
)
class
OpenVinoPlatform
(
Platform
):
class
OpenVinoPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OPENVINO
_enum
=
PlatformEnum
.
OPENVINO
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
04668ebe
...
@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
return
"
flash-attn-vllm-v
1"
return
"
FLASH_ATTN_VLLM_V
1"
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment