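# NOTE: the following lines are residue from a repository web view and are
# not part of the module's code:
#   "docs/vscode:/vscode.git/clone" did not exist on "e68b0dad8b4070e3ae24603c12f53b6c659ba6f9"
#   tpu.py 839 Bytes
#   Newer Older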
import os

3
4
import torch

5
6
7
8
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend

9
10
from .interface import Platform, PlatformEnum

11
12
13
# Default the compile level to DYNAMO_ONCE unless the environment already
# provides a value; an existing setting is left untouched.
os.environ.setdefault("VLLM_TORCH_COMPILE_LEVEL",
                      str(CompilationLevel.DYNAMO_ONCE))

# Compile levels at or above PIECEWISE are rejected at import time; per the
# assertion message, Inductor is not supported on TPU.
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE, (
    "TPU does not support Inductor.")

# Select "openxla" as the backend name passed to vLLM's torch.compile hook.
set_torch_compile_backend("openxla")

19
20
21
22

class TpuPlatform(Platform):
    """Platform implementation identified by ``PlatformEnum.TPU``.

    The device-name and total-memory queries are not implemented here and
    always raise ``NotImplementedError``.
    """

    _enum = PlatformEnum.TPU

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Device-name lookup; not implemented for this platform.

        Args:
            device_id: Index of the device being queried (unused).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Total-memory lookup; not implemented for this platform.

        Args:
            device_id: Index of the device being queried (unused).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context manager.

        This platform uses ``torch.no_grad()`` as its inference-mode
        context; a new context manager object is created on every call.
        """
        return torch.no_grad()