Unverified Commit 1a8bfd92 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Hardware] Initial TPU integration (#5292)

parent 847cdcca
ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
WORKDIR /workspace
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Build vLLM.
RUN cd /workspace/vllm && python setup.py develop
CMD ["/bin/bash"]
...@@ -189,7 +189,7 @@ if __name__ == '__main__': ...@@ -189,7 +189,7 @@ if __name__ == '__main__':
"--device", "--device",
type=str, type=str,
default="cuda", default="cuda",
choices=["cuda", "cpu"], choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.') help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
......
...@@ -346,7 +346,7 @@ if __name__ == "__main__": ...@@ -346,7 +346,7 @@ if __name__ == "__main__":
"--device", "--device",
type=str, type=str,
default="cuda", default="cuda",
choices=["cuda", "cpu"], choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.') help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument( parser.add_argument(
"--enable-prefix-caching", "--enable-prefix-caching",
......
.. _installation_tpu:
Installation with TPU
=====================
vLLM supports Google Cloud TPUs using PyTorch XLA.
Requirements
------------
* Google Cloud TPU VM (single host)
* TPU versions: v5e, v5p, v4
* Python: 3.10
Installation options:
1. :ref:`Build a docker image with Dockerfile <build_docker_tpu>`.
2. :ref:`Build from source <build_from_source_tpu>`.
.. _build_docker_tpu:
Build a docker image with :code:`Dockerfile.tpu`
------------------------------------------------
`Dockerfile.tpu <https://github.com/vllm-project/vllm/blob/main/Dockerfile.tpu>`_ is provided to build a docker image with TPU support.
.. code-block:: console
$ docker build -f Dockerfile.tpu -t vllm-tpu .
You can run the docker image with the following command:
.. code-block:: console
$ # Make sure to add `--privileged --net host --shm-size=16G`.
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu
.. _build_from_source_tpu:
Build from source
-----------------
You can also build and install the TPU backend from source.
First, install the dependencies:
.. code-block:: console
$ # (Recommended) Create a new conda environment.
$ conda create -n myenv python=3.10 -y
$ conda activate myenv
$ # Clean up the existing torch and torch-xla packages.
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240601"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ # Install JAX and Pallas.
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install packaging aiohttp
Next, build vLLM from source. This will only take a few seconds:
.. code-block:: console
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
...@@ -63,8 +63,9 @@ Documentation ...@@ -63,8 +63,9 @@ Documentation
getting_started/installation getting_started/installation
getting_started/amd-installation getting_started/amd-installation
getting_started/neuron-installation
getting_started/cpu-installation getting_started/cpu-installation
getting_started/neuron-installation
getting_started/tpu-installation
getting_started/quickstart getting_started/quickstart
getting_started/debugging getting_started/debugging
getting_started/examples/examples_index getting_started/examples/examples_index
......
# Common dependencies
-r requirements-common.txt
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
triton # To avoid import errors
...@@ -206,9 +206,9 @@ class cmake_build_ext(build_ext): ...@@ -206,9 +206,9 @@ class cmake_build_ext(build_ext):
def _is_cuda() -> bool: def _is_cuda() -> bool:
return VLLM_TARGET_DEVICE == "cuda" \ has_cuda = torch.version.cuda is not None
and torch.version.cuda is not None \ return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not _is_neuron() and not (_is_neuron() or _is_tpu()))
def _is_hip() -> bool: def _is_hip() -> bool:
...@@ -225,10 +225,18 @@ def _is_neuron() -> bool: ...@@ -225,10 +225,18 @@ def _is_neuron() -> bool:
return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"
def _is_tpu() -> bool:
return VLLM_TARGET_DEVICE == "tpu"
def _is_cpu() -> bool: def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu" return VLLM_TARGET_DEVICE == "cpu"
def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()
def _install_punica() -> bool: def _install_punica() -> bool:
return envs.VLLM_INSTALL_PUNICA_KERNELS return envs.VLLM_INSTALL_PUNICA_KERNELS
...@@ -325,6 +333,8 @@ def get_vllm_version() -> str: ...@@ -325,6 +333,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION: if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3] neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}" version += f"+neuron{neuron_version_str}"
elif _is_tpu():
version += "+tpu"
elif _is_cpu(): elif _is_cpu():
version += "+cpu" version += "+cpu"
else: else:
...@@ -372,6 +382,8 @@ def get_requirements() -> List[str]: ...@@ -372,6 +382,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-rocm.txt") requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron(): elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt") requirements = _read_requirements("requirements-neuron.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu(): elif _is_cpu():
requirements = _read_requirements("requirements-cpu.txt") requirements = _read_requirements("requirements-cpu.txt")
else: else:
...@@ -385,7 +397,7 @@ ext_modules = [] ...@@ -385,7 +397,7 @@ ext_modules = []
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if not _is_neuron(): if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
if _install_punica(): if _install_punica():
...@@ -428,6 +440,6 @@ setup( ...@@ -428,6 +440,6 @@ setup(
extras_require={ extras_require={
"tensorizer": ["tensorizer>=2.9.0"], "tensorizer": ["tensorizer>=2.9.0"],
}, },
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data, package_data=package_data,
) )
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "PallasMetadata":
return PallasMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
raise NotImplementedError("swap_blocks is not implemented.")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
# TODO(woosuk): Implement this.
raise NotImplementedError("copy_blocks is not implemented.")
@dataclass
class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor]
context_lens: Optional[torch.Tensor]
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
assert self.block_tables is None
assert self.context_lens is None
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
if not tpu_type.endswith("lite"):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert kv_scale == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache[0] is not None:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Decoding run.
assert kv_cache is not None
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = self.megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
)
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -18,6 +18,7 @@ class _Backend(enum.Enum): ...@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
PALLAS = enum.auto()
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
...@@ -66,6 +67,10 @@ def get_attn_backend( ...@@ -66,6 +67,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set.") "Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else: else:
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
...@@ -80,7 +85,6 @@ def which_attn_to_use( ...@@ -80,7 +85,6 @@ def which_attn_to_use(
block_size: int, block_size: int,
) -> _Backend: ) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
...@@ -100,6 +104,11 @@ def which_attn_to_use( ...@@ -100,6 +104,11 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
if is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip(): if is_hip():
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
......
...@@ -11,7 +11,7 @@ from vllm.logger import init_logger ...@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -748,6 +748,8 @@ class DeviceConfig: ...@@ -748,6 +748,8 @@ class DeviceConfig:
# Automated device type detection # Automated device type detection
if is_neuron(): if is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu(): elif is_cpu():
self.device_type = "cpu" self.device_type = "cpu"
else: else:
...@@ -761,6 +763,8 @@ class DeviceConfig: ...@@ -761,6 +763,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU # Some device types require processing inputs on CPU
if self.device_type in ["neuron"]: if self.device_type in ["neuron"]:
self.device = torch.device("cpu") self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else: else:
# Set device with device type # Set device with device type
self.device = torch.device(self.device_type) self.device = torch.device(self.device_type)
......
...@@ -504,7 +504,7 @@ class EngineArgs: ...@@ -504,7 +504,7 @@ class EngineArgs:
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"], choices=["auto", "cuda", "neuron", "cpu", "tpu"],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
......
...@@ -375,6 +375,9 @@ class AsyncLLMEngine: ...@@ -375,6 +375,9 @@ class AsyncLLMEngine:
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, ( assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.") "Distributed execution is not supported with the CPU backend.")
......
...@@ -341,6 +341,9 @@ class LLMEngine: ...@@ -341,6 +341,9 @@ class LLMEngine:
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor executor_class = CPUExecutor
......
...@@ -27,6 +27,7 @@ if TYPE_CHECKING: ...@@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds # Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
} }
# end-env-vars-definition # end-env-vars-definition
......
from typing import List, Set, Tuple
import torch
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class TPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert not self.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")
assert not self.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if self.model_config.dtype in (torch.float16, torch.float32):
logger.warning(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead.", self.model_config.dtype)
self.model_config.dtype = torch.bfloat16
# Instantiate the worker and load the model to the device.
self._init_worker()
def _init_worker(self):
from vllm.worker.tpu_worker import TPUWorker
assert self.parallel_config.world_size == 1, (
"TPUExecutor currently only supports a single TPU chip.")
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = TPUWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
self.vision_language_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def list_loras(self) -> Set[int]:
raise NotImplementedError("LoRA is not implemented for TPU backend.")
def check_health(self) -> None:
# TPUExecutor will always be healthy as long as it's running.
return
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
sexecute_model_req: ExecuteModelRequest,
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model
)(sexecute_model_req)
return output
import torch.nn as nn import torch.nn as nn
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_tpu
class CustomOp(nn.Module): class CustomOp(nn.Module):
...@@ -56,5 +56,7 @@ class CustomOp(nn.Module): ...@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
return self.forward_hip return self.forward_hip
elif is_cpu(): elif is_cpu():
return self.forward_cpu return self.forward_cpu
elif is_tpu():
return self.forward_tpu
else: else:
return self.forward_cuda return self.forward_cuda
...@@ -28,6 +28,7 @@ import torch ...@@ -28,6 +28,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_tpu
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: ...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
def _apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
class RotaryEmbedding(CustomOp): class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp): ...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype self.dtype = dtype
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
self.use_native2 = is_tpu() and is_neox_style
if not self.use_native2:
cache = cache.to(dtype) cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
else:
cos, sin = cache.chunk(2, dim=-1)
freqs_cis = cos + 1j * sin
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp): ...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size)
...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp): ...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key = key.flatten(-2) key = key.flatten(-2)
return query, key return query, key
def forward_native2(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
if positions.dim() == 1:
batch_size = 1
seq_len = positions.shape[0]
else:
batch_size, seq_len = positions.shape
if offsets is not None:
positions = positions + offsets
freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
query_shape = query.shape
query = query.view(batch_size, seq_len, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, freqs_cis)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(batch_size, seq_len, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, freqs_cis)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda( def forward_cuda(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp): ...@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_tpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
forward_fn = (self.forward_native2
if self.use_native2 else self.forward_native)
return forward_fn(positions, query, key, offsets)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
......
...@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator) pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -227,12 +228,26 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -227,12 +228,26 @@ class DefaultModelLoader(BaseModelLoader):
if self.load_config.load_format == LoadFormat.NPCACHE: if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints
assert use_safetensors is False assert use_safetensors is False
return np_cache_weights_iterator(model_name_or_path, weights_iterator = np_cache_weights_iterator(
self.load_config.download_dir, model_name_or_path, self.load_config.download_dir, hf_folder,
hf_folder, hf_weights_files) hf_weights_files)
if use_safetensors: elif use_safetensors:
return safetensors_weights_iterator(hf_weights_files) weights_iterator = safetensors_weights_iterator(hf_weights_files)
return pt_weights_iterator(hf_weights_files) else:
weights_iterator = pt_weights_iterator(hf_weights_files)
if is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
......
...@@ -146,6 +146,15 @@ def is_neuron() -> bool: ...@@ -146,6 +146,15 @@ def is_neuron() -> bool:
return transformers_neuronx is not None return transformers_neuronx is not None
@lru_cache(maxsize=None)
def is_tpu() -> bool:
try:
import libtpu
except ImportError:
libtpu = None
return libtpu is not None
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
...@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor, ...@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
return tensor return tensor
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()
def merge_dicts(dict1: Dict[Any, List[Any]], def merge_dicts(dict1: Dict[Any, List[Any]],
dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
"""Merge 2 dicts that have key -> List of items. """Merge 2 dicts that have key -> List of items.
......
...@@ -6,7 +6,8 @@ import torch ...@@ -6,7 +6,8 @@ import torch
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
is_pin_memory_available)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -108,9 +109,5 @@ class CacheEngine: ...@@ -108,9 +109,5 @@ class CacheEngine:
dtype = model_config.dtype dtype = model_config.dtype
else: else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = _get_dtype_size(dtype) dtype_size = get_dtype_size(dtype)
return dtype_size * total return dtype_size * total
def _get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment