Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
496e991d
Unverified
Commit 496e991d authored Oct 21, 2024 by Thomas Parnell
Committed by GitHub, Oct 21, 2024
Browse files
[Doc] Consistent naming of attention backends (#9498)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent
696b01af
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing 14 changed files with 23 additions and 19 deletions
+23
-19
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+1
-1
vllm/attention/backends/openvino.py
vllm/attention/backends/openvino.py
+1
-1
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+4
-0
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+1
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+1
-1
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+6
-6
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+1
-1
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+1
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+2
-2
No files found.
vllm/attention/backends/flash_attn.py
View file @
496e991d
...
...
@@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
flash-attn
"
return
"
FLASH_ATTN
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
...
...
vllm/attention/backends/flashinfer.py
View file @
496e991d
...
...
@@ -40,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
flashinfer
"
return
"
FLASHINFER
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashInferImpl"
]:
...
...
vllm/attention/backends/ipex_attn.py
View file @
496e991d
...
...
@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
ipex-attn
"
return
"
IPEX
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"IpexAttnBackendImpl"
]:
...
...
vllm/attention/backends/openvino.py
View file @
496e991d
...
...
@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
openvino
"
return
"
OPENVINO
"
@
staticmethod
def
get_impl_cls
():
...
...
vllm/attention/backends/pallas.py
View file @
496e991d
...
...
@@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState
class
PallasAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"PALLAS"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"PallasAttentionBackendImpl"
]:
return
PallasAttentionBackendImpl
...
...
vllm/attention/backends/placeholder_attn.py
View file @
496e991d
...
...
@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
placeholder-attn
"
return
"
NO_ATTENTION
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"PlaceholderAttentionImpl"
]:
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
496e991d
...
...
@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
rocm-flash-attn
"
return
"
ROCM_FLASH
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"ROCmFlashAttentionImpl"
]:
...
...
vllm/attention/backends/torch_sdpa.py
View file @
496e991d
...
...
@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
torch-sdpa
"
return
"
TORCH_SDPA
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"TorchSDPABackendImpl"
]:
...
...
vllm/attention/backends/utils.py
View file @
496e991d
...
...
@@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState):
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
xformers
"
,
\
f
"Expected attn_backend name to be '
xformers
', but "
\
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
XFORMERS
"
,
\
f
"Expected attn_backend name to be '
XFORMERS
', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_update_captured_metadata_for_enc_dec_model
(
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
...
...
@@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState):
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
xformers
"
,
\
f
"Expected attn_backend name to be '
xformers
', but "
\
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
XFORMERS
"
,
\
f
"Expected attn_backend name to be '
XFORMERS
', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
...
...
@@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState):
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
xformers
"
,
\
f
"Expected attn_backend name to be '
xformers
', but "
\
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"
XFORMERS
"
,
\
f
"Expected attn_backend name to be '
XFORMERS
', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_prepare_input_buffers_for_enc_dec_model
(
attn_metadata
,
input_buffers
)
...
...
vllm/attention/backends/xformers.py
View file @
496e991d
...
...
@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
xformers
"
return
"
XFORMERS
"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"XFormersImpl"
]:
...
...
vllm/spec_decode/draft_model_runner.py
View file @
496e991d
...
...
@@ -179,7 +179,7 @@ class TP1DraftModelRunner(ModelRunner):
return
False
# TODO: Add support for other attn backends
if
self
.
attn_backend
.
get_name
()
!=
"
flash-attn
"
:
if
self
.
attn_backend
.
get_name
()
!=
"
FLASH_ATTN
"
:
return
False
# TODO: Add support for LORA
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
496e991d
...
...
@@ -184,7 +184,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
not
disable_mqa_scorer
:
if
scorer_worker
.
model_runner
.
attn_backend
.
get_name
(
)
!=
"
flash-attn
"
:
)
!=
"
FLASH_ATTN
"
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
...
...
vllm/worker/model_runner.py
View file @
496e991d
...
...
@@ -1855,7 +1855,7 @@ class CUDAGraphRunner(nn.Module):
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
if
self
.
backend_name
!=
"
placeholder-attn
"
:
if
self
.
backend_name
!=
"
NO_ATTENTION
"
:
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
...
...
vllm/worker/multi_step_model_runner.py
View file @
496e991d
...
...
@@ -29,8 +29,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"
flash-attn"
,
"rocm-flash-attn"
,
"flashinfer
"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"
flash-attn
"
]
MULTI_STEP_ATTENTION_BACKENDS
=
[
"
FLASH_ATTN"
,
"ROCM_FLASH"
,
"FLASHINFER
"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
=
[
"
FLASH_ATTN
"
]
def
_get_supported_attention_backends
(
chunked_prefill_enabled
:
bool
)
\
->
List
[
str
]:
...
...
Write
Preview
Markdown is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment