Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4d2dc507
Unverified
Commit
4d2dc507
authored
Aug 13, 2024
by
youkaichao
Committed by
GitHub
Aug 13, 2024
Browse files
[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)
parent
7025b11d
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
29 additions
and
33 deletions
+29
-33
vllm/attention/selector.py
vllm/attention/selector.py
+2
-3
vllm/config.py
vllm/config.py
+4
-3
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+3
-2
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+3
-2
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+2
-2
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+3
-3
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+12
-9
vllm/utils.py
vllm/utils.py
+0
-9
No files found.
vllm/attention/selector.py
View file @
4d2dc507
...
@@ -10,8 +10,7 @@ import vllm.envs as envs
...
@@ -10,8 +10,7 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
is_cpu
,
is_hip
,
is_openvino
,
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
is_cpu
,
is_hip
,
is_openvino
,
is_xpu
is_tpu
,
is_xpu
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -194,7 +193,7 @@ def which_attn_to_use(
...
@@ -194,7 +193,7 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on XPU."
,
selected_backend
)
return
_Backend
.
IPEX
return
_Backend
.
IPEX
if
is_tpu
():
if
current_platform
.
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
return
_Backend
.
PALLAS
...
...
vllm/config.py
View file @
4d2dc507
...
@@ -10,11 +10,12 @@ import vllm.envs as envs
...
@@ -10,11 +10,12 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.platforms
import
current_platform
from
vllm.tracing
import
is_otel_installed
from
vllm.tracing
import
is_otel_installed
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
from
vllm.utils
import
(
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_tpu
,
is_xpu
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
print_warning_once
)
print_warning_once
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -282,7 +283,7 @@ class ModelConfig:
...
@@ -282,7 +283,7 @@ class ModelConfig:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
if
is_tpu
(
if
current_platform
.
is_tpu
(
)
and
self
.
quantization
not
in
tpu_supported_quantization
:
)
and
self
.
quantization
not
in
tpu_supported_quantization
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
...
@@ -910,7 +911,7 @@ class DeviceConfig:
...
@@ -910,7 +911,7 @@ class DeviceConfig:
self
.
device_type
=
"neuron"
self
.
device_type
=
"neuron"
elif
is_openvino
():
elif
is_openvino
():
self
.
device_type
=
"openvino"
self
.
device_type
=
"openvino"
elif
is_tpu
():
elif
current_platform
.
is_tpu
():
self
.
device_type
=
"tpu"
self
.
device_type
=
"tpu"
elif
is_cpu
():
elif
is_cpu
():
self
.
device_type
=
"cpu"
self
.
device_type
=
"cpu"
...
...
vllm/executor/ray_utils.py
View file @
4d2dc507
...
@@ -2,8 +2,9 @@ from typing import List, Optional, Tuple, Union
...
@@ -2,8 +2,9 @@ from typing import List, Optional, Tuple, Union
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
get_ip
,
is_hip
,
is_tpu
,
is_xpu
from
vllm.utils
import
get_ip
,
is_hip
,
is_xpu
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -111,7 +112,7 @@ def initialize_ray_cluster(
...
@@ -111,7 +112,7 @@ def initialize_ray_cluster(
# Placement group is already set.
# Placement group is already set.
return
return
device_str
=
"GPU"
if
not
is_tpu
()
else
"TPU"
device_str
=
"GPU"
if
not
current_platform
.
is_tpu
()
else
"TPU"
# Create placement group for worker processes
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
if
current_placement_group
:
...
...
vllm/model_executor/custom_op.py
View file @
4d2dc507
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
,
is_xpu
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
class
CustomOp
(
nn
.
Module
):
class
CustomOp
(
nn
.
Module
):
...
@@ -54,7 +55,7 @@ class CustomOp(nn.Module):
...
@@ -54,7 +55,7 @@ class CustomOp(nn.Module):
return
self
.
forward_hip
return
self
.
forward_hip
elif
is_cpu
():
elif
is_cpu
():
return
self
.
forward_cpu
return
self
.
forward_cpu
elif
is_tpu
():
elif
current_platform
.
is_tpu
():
return
self
.
forward_tpu
return
self
.
forward_tpu
elif
is_xpu
():
elif
is_xpu
():
return
self
.
forward_xpu
return
self
.
forward_xpu
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
4d2dc507
...
@@ -28,7 +28,7 @@ import torch
...
@@ -28,7 +28,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.
util
s
import
is_tpu
from
vllm.
platform
s
import
current_platform
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -78,7 +78,7 @@ class RotaryEmbedding(CustomOp):
...
@@ -78,7 +78,7 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
self
.
use_native2
=
is_tpu
()
and
is_neox_style
self
.
use_native2
=
current_platform
.
is_tpu
()
and
is_neox_style
if
not
self
.
use_native2
:
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
4d2dc507
...
@@ -41,7 +41,7 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
...
@@ -41,7 +41,7 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
supports_vision
)
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_pin_memory_available
,
is_tpu
from
vllm.utils
import
is_pin_memory_available
@
contextmanager
@
contextmanager
...
@@ -94,7 +94,7 @@ def _get_quantization_config(
...
@@ -94,7 +94,7 @@ def _get_quantization_config(
"""Get the quantization config."""
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
quant_config
=
get_quant_config
(
model_config
,
load_config
)
if
not
is_tpu
():
if
not
current_platform
.
is_tpu
():
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
if
capability
<
quant_config
.
get_min_capability
():
...
@@ -320,7 +320,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -320,7 +320,7 @@ class DefaultModelLoader(BaseModelLoader):
else
:
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
is_tpu
():
if
current_platform
.
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
...
...
vllm/platforms/__init__.py
View file @
4d2dc507
from
typing
import
Optional
import
torch
import
torch
from
vllm.utils
import
is_tpu
from
.interface
import
Platform
,
PlatformEnum
,
UnspecifiedPlatform
from
.interface
import
Platform
,
PlatformEnum
,
UnspecifiedPlatform
current_platform
:
Optional
[
Platform
]
current_platform
:
Platform
if
torch
.
version
.
cuda
is
not
None
:
try
:
import
libtpu
except
ImportError
:
libtpu
=
None
if
libtpu
is
not
None
:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from
.tpu
import
TpuPlatform
current_platform
=
TpuPlatform
()
elif
torch
.
version
.
cuda
is
not
None
:
from
.cuda
import
CudaPlatform
from
.cuda
import
CudaPlatform
current_platform
=
CudaPlatform
()
current_platform
=
CudaPlatform
()
elif
torch
.
version
.
hip
is
not
None
:
elif
torch
.
version
.
hip
is
not
None
:
from
.rocm
import
RocmPlatform
from
.rocm
import
RocmPlatform
current_platform
=
RocmPlatform
()
current_platform
=
RocmPlatform
()
elif
is_tpu
():
from
.tpu
import
TpuPlatform
current_platform
=
TpuPlatform
()
else
:
else
:
current_platform
=
UnspecifiedPlatform
()
current_platform
=
UnspecifiedPlatform
()
...
...
vllm/utils.py
View file @
4d2dc507
...
@@ -333,15 +333,6 @@ def is_neuron() -> bool:
...
@@ -333,15 +333,6 @@ def is_neuron() -> bool:
return
transformers_neuronx
is
not
None
return
transformers_neuronx
is
not
None
@
lru_cache
(
maxsize
=
None
)
def
is_tpu
()
->
bool
:
try
:
import
libtpu
except
ImportError
:
libtpu
=
None
return
libtpu
is
not
None
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
is_xpu
()
->
bool
:
def
is_xpu
()
->
bool
:
from
importlib.metadata
import
PackageNotFoundError
,
version
from
importlib.metadata
import
PackageNotFoundError
,
version
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment