Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c1fb507
Unverified
Commit
8c1fb507
authored
Nov 19, 2024
by
Mengqing Cao
Committed by
GitHub
Nov 19, 2024
Browse files
[Platform][Refactor] Extract func `get_default_attn_backend` to `Platform` (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent
7eb719df
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing 14 changed files with 99 additions and 69 deletions
+99
-69
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+11
-8
vllm/attention/selector.py
vllm/attention/selector.py
+6
-50
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+1
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-1
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+2
-2
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+1
-0
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+9
-1
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+5
-1
vllm/platforms/interface.py
vllm/platforms/interface.py
+19
-0
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+7
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+13
-1
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+11
-1
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+11
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+2
-1
No files found.
tests/kernels/test_attention_selector.py
View file @
8c1fb507
...
...
@@ -5,6 +5,7 @@ import torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.platforms
import
cpu
,
cuda
,
openvino
,
rocm
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
...
...
@@ -19,24 +20,26 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable
(
monkeypatch
,
name
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform
.is_cpu
"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform"
,
cpu
.
CpuPlatform
()
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform
.is_rocm
"
,
r
eturn_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform"
,
r
ocm
.
RocmPlatform
()
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.current_platform
.is_openvino
"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.current_platform"
,
openvino
.
OpenVinoPlatform
()
):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"OPENVINO"
else
:
with
patch
(
"vllm.attention.selector.current_platform"
,
cuda
.
CudaPlatform
()):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
name
...
...
vllm/attention/selector.py
View file @
8c1fb507
import
enum
import
os
from
contextlib
import
contextmanager
from
functools
import
lru_cache
...
...
@@ -9,26 +8,12 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
STR_BACKEND_ENV_VAR
logger
=
init_logger
(
__name__
)
class
_Backend
(
enum
.
Enum
):
FLASH_ATTN
=
enum
.
auto
()
FLASH_ATTN_VLLM_V1
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
def
backend_name_to_enum
(
backend_name
:
str
)
->
_Backend
:
assert
backend_name
is
not
None
...
...
@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if
backend_by_env_var
is
not
None
:
selected_backend
=
backend_name_to_enum
(
backend_by_env_var
)
if
current_platform
.
is_cpu
():
if
selected_backend
!=
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
if
current_platform
.
is_openvino
():
if
selected_backend
!=
_Backend
.
OPENVINO
:
logger
.
info
(
"Cannot use %s backend on OpenVINO."
,
selected_backend
)
return
_Backend
.
OPENVINO
if
current_platform
.
is_xpu
():
if
selected_backend
!=
_Backend
.
IPEX
:
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
return
_Backend
.
IPEX
if
current_platform
.
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
if
current_platform
.
is_rocm
():
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
current_platform
.
has_device_capability
(
90
):
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
return
_Backend
.
ROCM_FLASH
if
current_platform
.
is_hpu
():
return
_Backend
.
HPU_ATTN
# get device-specific default attn_backend
default_backend
=
current_platform
.
get_default_attn_backend
(
selected_backend
)
if
default_backend
is
not
None
:
return
default_backend
if
use_v1
:
return
_Backend
.
FLASH_ATTN_VLLM_V1
...
...
vllm/model_executor/models/molmo.py
View file @
8c1fb507
...
...
@@ -13,7 +13,6 @@ from torch.nn import functional as F
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention.selector
import
_Backend
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
...
...
@@ -38,6 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
8c1fb507
...
...
@@ -39,7 +39,6 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images
,
make_batched_videos
,
smart_resize
)
from
vllm.attention
import
AttentionMetadata
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -65,6 +64,7 @@ from vllm.multimodal.image import cached_get_image_processor
from
vllm.multimodal.inputs
import
(
MultiModalData
,
MultiModalDataDict
,
MultiModalKwargs
)
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
,
SequenceData
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
cached_get_processor
...
...
vllm/model_executor/models/utils.py
View file @
8c1fb507
...
...
@@ -9,13 +9,13 @@ from torch.func import functional_call
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
from
vllm.attention.selector
import
(
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal
import
MultiModalPlaceholderMap
,
NestedTensors
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
...
...
vllm/platforms/__init__.py
View file @
8c1fb507
from
.interface
import
_Backend
# noqa: F401
from
.interface
import
Platform
,
PlatformEnum
,
UnspecifiedPlatform
current_platform
:
Platform
...
...
vllm/platforms/cpu.py
View file @
8c1fb507
...
...
@@ -5,7 +5,9 @@ import torch
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
...
...
@@ -22,6 +24,12 @@ class CpuPlatform(Platform):
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
"cpu"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on CPU.

    The result is always ``_Backend.TORCH_SDPA``. When a different backend
    was requested, an info message is logged to say it cannot be used.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``_Backend.TORCH_SDPA``.
    """
    if selected_backend != _Backend.TORCH_SDPA:
        logger.info("Cannot use %s backend on CPU.", selected_backend)
    return _Backend.TORCH_SDPA
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total`` field of ``psutil.virtual_memory()``.

    Args:
        device_id: Accepted but not used by this implementation.
    """
    return psutil.virtual_memory().total
...
...
vllm/platforms/hpu.py
View file @
8c1fb507
import
torch
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
class HpuPlatform(Platform):
    """Platform implementation tagged with ``PlatformEnum.HPU``."""

    _enum = PlatformEnum.HPU

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        """Return ``_Backend.HPU_ATTN`` regardless of ``selected_backend``."""
        return _Backend.HPU_ATTN

    @staticmethod
    def inference_mode():
        """Return a ``torch.no_grad()`` context manager."""
        return torch.no_grad()
vllm/platforms/interface.py
View file @
8c1fb507
...
...
@@ -11,6 +11,20 @@ else:
VllmConfig
=
None
class
_Backend
(
enum
.
Enum
):
FLASH_ATTN
=
enum
.
auto
()
FLASH_ATTN_VLLM_V1
=
enum
.
auto
()
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
class
PlatformEnum
(
enum
.
Enum
):
CUDA
=
enum
.
auto
()
ROCM
=
enum
.
auto
()
...
...
@@ -71,6 +85,11 @@ class Platform:
"""Stateless version of :func:`torch.cuda.is_available`."""
return
self
.
_enum
in
(
PlatformEnum
.
CUDA
,
PlatformEnum
.
ROCM
)
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
    """Get the default attention backend of a device.

    This base implementation ignores ``selected_backend`` and returns
    ``None``.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``None``.
    """
    return None
@
classmethod
def
get_device_capability
(
cls
,
...
...
vllm/platforms/openvino.py
View file @
8c1fb507
...
...
@@ -3,7 +3,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
logger
=
init_logger
(
__name__
)
...
...
@@ -11,6 +11,12 @@ logger = init_logger(__name__)
class
OpenVinoPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OPENVINO
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on OpenVINO.

    The result is always ``_Backend.OPENVINO``. When a different backend
    was requested, an info message is logged to say it cannot be used.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``_Backend.OPENVINO``.
    """
    if selected_backend != _Backend.OPENVINO:
        logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
    return _Backend.OPENVINO
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the fixed device name ``"openvino"``.

    Args:
        device_id: Accepted but not used by this implementation.
    """
    return "openvino"
...
...
vllm/platforms/rocm.py
View file @
8c1fb507
...
...
@@ -5,7 +5,7 @@ import torch
from
vllm.logger
import
init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
logger
=
init_logger
(
__name__
)
...
...
@@ -19,6 +19,18 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on ROCm.

    The result is always ``_Backend.ROCM_FLASH``. The requested backend
    only determines which info message, if any, is logged:

    * ``FLASH_ATTN`` is treated the same as ``ROCM_FLASH``.
    * For ``ROCM_FLASH``, a message is logged when
      ``cls.has_device_capability(90)`` is false.
    * For any other backend, a message is logged saying that it is not
      supported.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``_Backend.ROCM_FLASH``.
    """
    # A FLASH_ATTN request is mapped onto ROCM_FLASH before the checks.
    if selected_backend == _Backend.FLASH_ATTN:
        selected_backend = _Backend.ROCM_FLASH

    if selected_backend == _Backend.ROCM_FLASH:
        if not cls.has_device_capability(90):
            # not Instinct series GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
    else:
        logger.info("%s is not supported in AMD GPUs.", selected_backend)
    return _Backend.ROCM_FLASH
@
classmethod
@
lru_cache
(
maxsize
=
8
)
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
...
...
vllm/platforms/tpu.py
View file @
8c1fb507
...
...
@@ -3,17 +3,27 @@ from typing import TYPE_CHECKING
import
torch
from
.interface
import
Platform
,
PlatformEnum
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
else
:
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on TPU.

    The result is always ``_Backend.PALLAS``. When a different backend was
    requested, an info message is logged to say it cannot be used.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``_Backend.PALLAS``.
    """
    if selected_backend != _Backend.PALLAS:
        logger.info("Cannot use %s backend on TPU.", selected_backend)
    return _Backend.PALLAS
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Device-name lookup; not implemented for this platform.

    Args:
        device_id: Accepted but not used.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
...
...
vllm/platforms/xpu.py
View file @
8c1fb507
import
torch
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
from
vllm.logger
import
init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
logger
=
init_logger
(
__name__
)
class
XPUPlatform
(
Platform
):
_enum
=
PlatformEnum
.
XPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
    """Return the attention backend used on XPU.

    The result is always ``_Backend.IPEX``. When a different backend was
    requested, an info message is logged to say it cannot be used.

    Args:
        selected_backend: The backend that was requested.

    Returns:
        ``_Backend.IPEX``.
    """
    if selected_backend != _Backend.IPEX:
        logger.info("Cannot use %s backend on XPU.", selected_backend)
    return _Backend.IPEX
@
staticmethod
def
get_device_capability
(
device_id
:
int
=
0
)
->
DeviceCapability
:
major
,
minor
,
*
_
=
torch
.
xpu
.
get_device_capability
(
...
...
vllm/worker/enc_dec_model_runner.py
View file @
8c1fb507
...
...
@@ -8,7 +8,7 @@ import torch.distributed
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
)
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.attention.selector
import
(
_Backend
,
get_env_variable_attn_backend
,
from
vllm.attention.selector
import
(
get_env_variable_attn_backend
,
get_global_forced_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor import SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.platforms
import
_Backend
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SequenceGroupMetadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment