Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5510cf0e
Unverified
Commit
5510cf0e
authored
May 08, 2024
by
Woosuk Kwon
Committed by
GitHub
May 08, 2024
Browse files
[Misc] Add `get_name` method to attention backends (#4685)
parent
0f9a6e3d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
30 additions
and
12 deletions
+30
-12
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+5
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+4
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+7
-9
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+4
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+4
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-3
No files found.
vllm/attention/backends/abstract.py
View file @
5510cf0e
...
...
@@ -9,6 +9,11 @@ import torch
class
AttentionBackend
(
ABC
):
"""Abstract class for attention backends."""
@
staticmethod
@
abstractmethod
def
get_name
()
->
str
:
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
get_impl_cls
()
->
Type
[
"AttentionImpl"
]:
...
...
vllm/attention/backends/flash_attn.py
View file @
5510cf0e
...
...
@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class
FlashAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"flash-attn"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
return
FlashAttentionImpl
...
...
vllm/attention/backends/flashinfer.py
View file @
5510cf0e
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
import
flashinfer
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
except
ImportError
:
flashinfer
=
None
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
import
flashinfer
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
...
@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
class
FlashInferBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"flashinfer"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashInferImpl"
]:
return
FlashInferImpl
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
5510cf0e
...
...
@@ -17,6 +17,10 @@ logger = init_logger(__name__)
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"rocm-flash-attn"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"ROCmFlashAttentionImpl"
]:
return
ROCmFlashAttentionImpl
...
...
vllm/attention/backends/torch_sdpa.py
View file @
5510cf0e
...
...
@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class
TorchSDPABackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"torch-sdpa"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"TorchSDPABackendImpl"
]:
return
TorchSDPABackendImpl
...
...
vllm/attention/backends/xformers.py
View file @
5510cf0e
...
...
@@ -20,6 +20,10 @@ logger = init_logger(__name__)
class
XFormersBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"xformers"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"XFormersImpl"
]:
return
XFormersImpl
...
...
vllm/worker/model_runner.py
View file @
5510cf0e
...
...
@@ -9,7 +9,6 @@ import torch.nn as nn
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
get_attn_backend
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
...
...
@@ -395,7 +394,7 @@ class ModelRunner:
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
if
self
.
attn_backend
is
FlashInferBackend
:
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
use_cuda_graph
=
False
,
...
...
@@ -556,7 +555,7 @@ class ModelRunner:
device
=
self
.
device
,
)
if
self
.
attn_backend
is
FlashInferBackend
:
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
not
hasattr
(
self
,
"flashinfer_workspace_buffer"
):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment