Unverified Commit 7f8d612d authored by Earthwalker's avatar Earthwalker Committed by GitHub
Browse files

[TPU] Support tensor parallelism in async llm engine (#6891)

parent 60d1c6e5
......@@ -12,6 +12,9 @@ RUN pip install "numpy<2"
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Fix FastAPI dependence
RUN pip install "starlette<0.38.0"
# Build vLLM.
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
......
......@@ -407,6 +407,12 @@ class AsyncLLMEngine:
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment