# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from typing import Any
5
6
7
8
9
10
11

import torch
import torch.distributed

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
12
from vllm.profiler.wrapper import TorchProfilerWrapper
13
from vllm.utils.mem_utils import MemorySnapshot, format_gib
14
from vllm.utils.torch_utils import set_random_seed
15
from vllm.v1.utils import report_usage_stats
16
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
17
from vllm.v1.worker.workspace import init_workspace_manager
18
19
from vllm.v1.worker.xpu_model_runner import XPUModelRunner

20
21
from .utils import request_memory

# Module-level logger; used below for debug output during device initialization.
logger = init_logger(__name__)


class XPUWorker(Worker):
    """A worker that runs the model on an XPU device.

    Extends the generic ``Worker`` with XPU-specific device setup, an
    XPU-aware torch profiler, and an ``XPUModelRunner``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the base worker and check that the device is XPU.

        Args:
            vllm_config: Full engine configuration, forwarded to ``Worker``.
            local_rank: Rank of this worker on its node. Forwarded to
                ``Worker`` and used for the profiler.
            rank: Global rank of this worker, forwarded to ``Worker``.
            distributed_init_method: Forwarded to ``Worker``.
            is_driver_worker: Forwarded to ``Worker``.
        """
        super().__init__(
            vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
        )
        # Sanity checks: this worker class only makes sense on the XPU
        # platform with an XPU device configured.
        device_config = self.device_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        # Torch profiler. Enabled and configured through profiler_config.
        # Overrides whatever the base class set up, so that XPU activity
        # (rather than another accelerator's) is recorded.
        self.profiler: Any | None = None
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler == "torch":
            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
            self.profiler = TorchProfilerWrapper(
                profiler_config,
                worker_name=worker_name,
                local_rank=self.local_rank,
                activities=["CPU", "XPU"],
            )

    def init_device(self):
        """Bind this worker to its XPU device and set up runtime state.

        Steps, in order:
          1. Select ``xpu:<local_rank>``, check dtype support, and record
             the device's total memory in ``self.init_gpu_memory``.
          2. Export the CCL / local-topology environment variables.
          3. Initialize the distributed environment.
          4. Take a memory snapshot and compute the requested memory.
          5. Seed the RNGs, create the workspace manager and the
             ``XPUModelRunner``, and (on rank 0) report usage stats.

        Raises:
            RuntimeError: If the configured device is not an XPU
                ``torch.device`` or the current platform is not XPU.
        """
        device = self.device_config.device
        is_xpu_device = (
            isinstance(device, torch.device)
            and device.type == "xpu"
            and current_platform.is_xpu()
        )
        if not is_xpu_device:
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"xpu:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)
        torch.xpu.empty_cache()
        # Total memory (bytes) reported by the device properties.
        self.init_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank
        ).total_memory

        # Keep any values the user already exported and fill in defaults
        # otherwise. These must be set before the distributed environment
        # is initialized below (presumably so the CCL backend reads them
        # there -- verify against the backend's docs).
        os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
        os.environ.setdefault(
            "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size)
        )
        os.environ["LOCAL_RANK"] = str(self.local_rank)

        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # Drop cached allocator blocks so the snapshot below is not inflated
        # by memory that is merely cached.
        torch.xpu.empty_cache()
        self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
        self.requested_memory = request_memory(init_snapshot, self.cache_config)
        logger.debug("worker init memory snapshot: %r", self.init_snapshot)
        logger.debug(
            "worker requested memory: %sGiB", format_gib(self.requested_memory)
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Initialize workspace manager; two micro-batches when DBO is on.
        num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
        init_workspace_manager(self.device, num_ubatches)

        # Construct the model runner
        self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device
        )

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
            report_usage_stats(self.vllm_config)