Unverified Commit 76afe4ed authored by Aydin Abiar's avatar Aydin Abiar Committed by GitHub
Browse files

[Bugfix] Fix `vllm bench ...` on CPU-only head nodes (#25283)


Signed-off-by: default avatarAydin Abiar <aydin@anyscale.com>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarAydin Abiar <aydin@anyscale.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent c1b06fc1
...@@ -8,6 +8,11 @@ to avoid certain eager import breakage.""" ...@@ -8,6 +8,11 @@ to avoid certain eager import breakage."""
from __future__ import annotations from __future__ import annotations
import importlib.metadata import importlib.metadata
import sys
from vllm.logger import init_logger
logger = init_logger(__name__)
def main(): def main():
...@@ -29,6 +34,22 @@ def main(): ...@@ -29,6 +34,22 @@ def main():
cli_env_setup() cli_env_setup()
# For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
if len(sys.argv) > 1 and sys.argv[1] == "bench":
logger.debug(
"Bench command detected, must ensure current platform is not "
"UnspecifiedPlatform to avoid device type inference error"
)
from vllm import platforms
if platforms.current_platform.is_unspecified():
from vllm.platforms.cpu import CpuPlatform
platforms.current_platform = CpuPlatform()
logger.info(
"Unspecified platform detected, switching to CPU Platform instead."
)
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM CLI", description="vLLM CLI",
epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
......
...@@ -261,4 +261,14 @@ def __getattr__(name: str): ...@@ -261,4 +261,14 @@ def __getattr__(name: str):
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
def __setattr__(name: str, value):
if name == "current_platform":
global _current_platform
_current_platform = value
elif name in globals():
globals()[name] = value
else:
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] __all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
...@@ -141,6 +141,9 @@ class Platform: ...@@ -141,6 +141,9 @@ class Platform:
def is_out_of_tree(self) -> bool: def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT return self._enum == PlatformEnum.OOT
def is_unspecified(self) -> bool:
return self._enum == PlatformEnum.UNSPECIFIED
def get_max_output_tokens(self, prompt_len: int) -> int: def get_max_output_tokens(self, prompt_len: int) -> int:
return sys.maxsize return sys.maxsize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment