device.py 2.66 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import field
5
from typing import Any, Literal
6
7
8
9
10

import torch
from pydantic import ConfigDict, SkipValidation

from vllm.config.utils import config
11
from vllm.utils.hashing import safe_hash
12
13
14
15

Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]


16
17
@config(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Device | torch.device | None] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # the device/platform information will be summarized
        # by torch/vllm automatically.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self) -> None:
        """Resolve `device_type` and normalize `device`.

        When `device` is "auto", the device type is read from
        `vllm.platforms.current_platform`. Otherwise it is taken from the
        explicit string, or from the `type` attribute of a `torch.device`.
        Afterwards `device` is replaced by `torch.device(device_type)`,
        except for "tpu", where it is set to `None`.

        Raises:
            RuntimeError: If `device` is "auto" and the current platform
                reports an empty device type.
            TypeError: If `device` is neither a string nor a
                `torch.device` (this includes an explicit `None`).
        """
        if self.device == "auto":
            # Automated device type detection
            from vllm.platforms import current_platform

            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue."
                )
        # Device type is assigned explicitly
        elif isinstance(self.device, str):
            self.device_type = self.device
        elif isinstance(self.device, torch.device):
            self.device_type = self.device.type
        else:
            # Validation is skipped for `device`, so unsupported values can
            # reach this point. Fail here with a clear message instead of
            # leaving `device_type` unset and hitting an AttributeError below.
            raise TypeError(
                "`device` must be a string or a `torch.device`, "
                f"got {type(self.device).__name__}."
            )

        # Some device types require processing inputs on CPU
        if self.device_type in ["tpu"]:
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)