Unverified Commit 8f0d5167 authored by wenxindongwork's avatar wenxindongwork Committed by GitHub
Browse files

[TPU] Support Pathways in vLLM (#21417)


Signed-off-by: default avatarwenxindongwork <wenxindong@google.com>
parent f4135232
......@@ -124,6 +124,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
......@@ -900,6 +901,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_MOST_MODEL_LEN":
lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
# Whether using Pathways
"VLLM_TPU_USING_PATHWAYS":
lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl
......@@ -31,20 +31,26 @@ def vllm_version_matches_substr(substr: str) -> bool:
def tpu_platform_plugin() -> Optional[str]:
is_tpu = False
logger.debug("Checking if TPU platform is available.")
# Check for Pathways TPU proxy
if envs.VLLM_TPU_USING_PATHWAYS:
logger.debug("Confirmed TPU platform is available via Pathways proxy.")
return "tpu_commons.platforms.tpu_jax.TpuPlatform"
# Check for libtpu installation
try:
# While it's technically possible to install libtpu on a
# non-TPU machine, this is a very uncommon scenario. Therefore,
# we assume that libtpu is installed if and only if the machine
# we assume that libtpu is installed only if the machine
# has TPUs.
import libtpu # noqa: F401
is_tpu = True
logger.debug("Confirmed TPU platform is available.")
return "vllm.platforms.tpu.TpuPlatform"
except Exception as e:
logger.debug("TPU platform is not available because: %s", str(e))
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
return None
def cuda_platform_plugin() -> Optional[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment