# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/interface.py

import enum
import random
from typing import NamedTuple, Optional, Tuple, Union

import numpy as np
import torch

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    SLIDING_TILE_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    CPU = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class Platform:
    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    supported_quantization: list[str] = []

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        # TODO(will): ROCM will be supported in the future here
        return self._enum == PlatformEnum.CUDA

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: Optional[_Backend],
                             head_size: int, dtype: torch.dtype) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of :func:`torch.cuda.get_device_capability`."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as
        TPU do not support `torch.inference_mode`. In such a case, they will
        fall back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: Optional[int] = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
          the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
        if cls.supported_quantization and \
                quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.v1.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""
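

if __name__ == "__main__":
    # Illustrative sketch only: `_DemoCudaPlatform` is a hypothetical
    # subclass (not FastVideo's real CUDA platform) showing the intended
    # contract -- a concrete platform sets `_enum` / `device_name` /
    # `device_type` and overrides the stateless classmethods with
    # device-specific queries.

    class _DemoCudaPlatform(Platform):
        _enum = PlatformEnum.CUDA
        device_name = "cuda"
        device_type = "cuda"
        dispatch_key = "CUDA"

        @classmethod
        def get_device_capability(
                cls, device_id: int = 0) -> Optional[DeviceCapability]:
            major, minor = torch.cuda.get_device_capability(device_id)
            return DeviceCapability(major=major, minor=minor)

        @classmethod
        def get_device_name(cls, device_id: int = 0) -> str:
            return torch.cuda.get_device_name(device_id)

    if torch.cuda.is_available():
        cap = _DemoCudaPlatform.get_device_capability()
        logger.info("device: %s, capability: %s (int form %d)",
                    _DemoCudaPlatform.get_device_name(),
                    cap.as_version_str(), cap.to_int())
        # `has_device_capability` accepts either a (major, minor) tuple or
        # the integer <major><minor> form from DeviceCapability.to_int(),
        # so these two checks are equivalent.
        logger.info("SM 8.0+ available: %s",
                    _DemoCudaPlatform.has_device_capability((8, 0)))
        logger.info("SM 8.0+ available (int form): %s",
                    _DemoCudaPlatform.has_device_capability(80))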