Unverified Commit 48ac2bed authored by Siyuan Liu's avatar Siyuan Liu Committed by GitHub
Browse files

[Hardware][TPU] Optionally import for TPU backend (#18269)


Signed-off-by: default avatarSiyuan Liu <lsiyuan@google.com>
Signed-off-by: default avatarJade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: default avatarCarol Zheng <cazheng@google.com>
Co-authored-by: default avatarJade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: default avatarHongmin Fan <fanhongmin@google.com>
parent 3e0d4350
...@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase): ...@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather." assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim) return xm.all_gather(input_, dim=dim)
try:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
...@@ -194,3 +194,11 @@ class TpuPlatform(Platform): ...@@ -194,3 +194,11 @@ class TpuPlatform(Platform):
if params.sampling_type == SamplingType.RANDOM_SEED: if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError( raise ValueError(
"Torch XLA does not support per-request seed.") "Torch XLA does not support per-request seed.")
try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
TpuPlatform = TpuCommonsPlatform # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuPlatform")
pass
...@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment( ...@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
parallel_config.enable_expert_parallel) parallel_config.enable_expert_parallel)
try:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
TPUWorker = TPUCommonsWorker # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment