Unverified Commit 335b28f7 authored by Utkarsh Sharma's avatar Utkarsh Sharma Committed by GitHub
Browse files

[TPU] Rename tpu_commons to tpu_inference (#26279)


Signed-off-by: default avatarUtkarsh Sharma <utksharma@google.com>
Co-authored-by: default avatarUtkarsh Sharma <utksharma@google.com>
Co-authored-by: default avatarChengji Yao <chengjiyao@google.com>
parent 5e65d6b2
...@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup ...@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.platforms.tpu import USE_TPU_INFERENCE
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
...@@ -20,8 +20,8 @@ USE_RAY = parallel_config = ( ...@@ -20,8 +20,8 @@ USE_RAY = parallel_config = (
logger = init_logger(__name__) logger = init_logger(__name__)
if not USE_TPU_COMMONS: if not USE_TPU_INFERENCE:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator") logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu(): if current_platform.is_tpu():
import torch_xla import torch_xla
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
...@@ -100,9 +100,9 @@ class TpuCommunicator(DeviceCommunicatorBase): ...@@ -100,9 +100,9 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim) return xm.all_gather(input_, dim=dim)
if USE_TPU_COMMONS: if USE_TPU_INFERENCE:
from tpu_commons.distributed.device_communicators import ( from tpu_inference.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator, TpuCommunicator as TpuInferenceCommunicator,
) )
TpuCommunicator = TpuCommonsCommunicator # type: ignore TpuCommunicator = TpuInferenceCommunicator # type: ignore
...@@ -223,9 +223,9 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -223,9 +223,9 @@ class DefaultModelLoader(BaseModelLoader):
) )
if current_platform.is_tpu(): if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.platforms.tpu import USE_TPU_INFERENCE
if not USE_TPU_COMMONS: if not USE_TPU_INFERENCE:
# In PyTorch XLA, we should call `torch_xla.sync` # In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated # frequently so that not too many ops are accumulated
# in the XLA program. # in the XLA program.
......
...@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]: ...@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
# Check for Pathways TPU proxy # Check for Pathways TPU proxy
if envs.VLLM_TPU_USING_PATHWAYS: if envs.VLLM_TPU_USING_PATHWAYS:
logger.debug("Confirmed TPU platform is available via Pathways proxy.") logger.debug("Confirmed TPU platform is available via Pathways proxy.")
return "tpu_commons.platforms.tpu_jax.TpuPlatform" return "tpu_inference.platforms.tpu_jax.TpuPlatform"
# Check for libtpu installation # Check for libtpu installation
try: try:
......
...@@ -26,7 +26,7 @@ else: ...@@ -26,7 +26,7 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
USE_TPU_COMMONS = False USE_TPU_INFERENCE = False
class TpuPlatform(Platform): class TpuPlatform(Platform):
...@@ -254,10 +254,10 @@ class TpuPlatform(Platform): ...@@ -254,10 +254,10 @@ class TpuPlatform(Platform):
try: try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
TpuPlatform = TpuCommonsPlatform # type: ignore TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_COMMONS = True USE_TPU_INFERENCE = True
except ImportError: except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuPlatform") logger.info("tpu_inference not found, using vLLM's TpuPlatform")
pass pass
...@@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
} }
try: try:
import tpu_commons # noqa: F401 import tpu_inference # noqa: F401
except ImportError: except ImportError:
# Lazy import torch_xla # Lazy import torch_xla
import torch_xla.core.xla_builder as xb import torch_xla.core.xla_builder as xb
......
...@@ -23,7 +23,7 @@ from vllm.logger import init_logger ...@@ -23,7 +23,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -36,8 +36,8 @@ logger = init_logger(__name__) ...@@ -36,8 +36,8 @@ logger = init_logger(__name__)
_R = TypeVar("_R") _R = TypeVar("_R")
if not USE_TPU_COMMONS: if not USE_TPU_INFERENCE:
logger.info("tpu_commons not found, using vLLM's TPUWorker.") logger.info("tpu_inference not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr import torch_xla.runtime as xr
...@@ -346,7 +346,7 @@ class TPUWorker: ...@@ -346,7 +346,7 @@ class TPUWorker:
return fn(self.get_model()) return fn(self.get_model())
if USE_TPU_COMMONS: if USE_TPU_INFERENCE:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker from tpu_inference.worker import TPUWorker as TpuInferenceWorker
TPUWorker = TPUCommonsWorker # type: ignore TPUWorker = TpuInferenceWorker # type: ignore
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment