Unverified commit 335b28f7, authored by Utkarsh Sharma and committed by GitHub
Browse files

[TPU] Rename tpu_commons to tpu_inference (#26279)


Signed-off-by: Utkarsh Sharma <utksharma@google.com>
Co-authored-by: Utkarsh Sharma <utksharma@google.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
parent 5e65d6b2
......@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.platforms.tpu import USE_TPU_INFERENCE
from .base_device_communicator import DeviceCommunicatorBase
......@@ -20,8 +20,8 @@ USE_RAY = parallel_config = (
logger = init_logger(__name__)
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
......@@ -100,9 +100,9 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim)
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator,
if USE_TPU_INFERENCE:
from tpu_inference.distributed.device_communicators import (
TpuCommunicator as TpuInferenceCommunicator,
)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
TpuCommunicator = TpuInferenceCommunicator # type: ignore
......@@ -223,9 +223,9 @@ class DefaultModelLoader(BaseModelLoader):
)
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.platforms.tpu import USE_TPU_INFERENCE
if not USE_TPU_COMMONS:
if not USE_TPU_INFERENCE:
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program.
......
......@@ -37,7 +37,7 @@ def tpu_platform_plugin() -> Optional[str]:
# Check for Pathways TPU proxy
if envs.VLLM_TPU_USING_PATHWAYS:
logger.debug("Confirmed TPU platform is available via Pathways proxy.")
return "tpu_commons.platforms.tpu_jax.TpuPlatform"
return "tpu_inference.platforms.tpu_jax.TpuPlatform"
# Check for libtpu installation
try:
......
......@@ -26,7 +26,7 @@ else:
logger = init_logger(__name__)
USE_TPU_COMMONS = False
USE_TPU_INFERENCE = False
class TpuPlatform(Platform):
......@@ -254,10 +254,10 @@ class TpuPlatform(Platform):
try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
TpuPlatform = TpuCommonsPlatform # type: ignore
USE_TPU_COMMONS = True
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuPlatform")
logger.info("tpu_inference not found, using vLLM's TpuPlatform")
pass
......@@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
}
try:
import tpu_commons # noqa: F401
import tpu_inference # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
......
......@@ -23,7 +23,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.core.sched.output import SchedulerOutput
......@@ -36,8 +36,8 @@ logger = init_logger(__name__)
_R = TypeVar("_R")
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
......@@ -346,7 +346,7 @@ class TPUWorker:
return fn(self.get_model())
if USE_TPU_COMMONS:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
if USE_TPU_INFERENCE:
from tpu_inference.worker import TPUWorker as TpuInferenceWorker
TPUWorker = TPUCommonsWorker # type: ignore
TPUWorker = TpuInferenceWorker # type: ignore
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment