# SPDX-License-Identifier: Apache-2.0
"""An XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

12
from vllm.config import VllmConfig
13
14
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
15
from vllm.distributed.parallel_state import get_pp_group
16
17
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
18
from vllm.platforms import current_platform
19
20
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
21
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
22
23
24
25
26
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


27
class XPUWorker(LoRANotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """Set up per-worker state and the XPU model runner.

        Args:
            vllm_config: Engine configuration. ``WorkerBase.__init__`` exposes
                its sub-configs (``device_config``, ``parallel_config``,
                ``cache_config``, ``model_config``) on ``self``.
            local_rank: Device index on this node; the worker later binds to
                ``xpu:<local_rank>`` in ``init_device``.
            rank: Global rank of this worker; also stored into
                ``parallel_config.rank``.
            distributed_init_method: Init method handed to
                ``init_distributed_environment`` when ``torch.distributed``
                is not yet initialized.
            is_driver_worker: Whether this worker drives its tensor-parallel
                group; such a worker must have a rank that is a multiple of
                the tensor-parallel size.
        """
        # Only WorkerBase is initialized here, not Worker.__init__; presumably
        # to avoid the parent's non-XPU setup — confirm against Worker.
        WorkerBase.__init__(self, vllm_config=vllm_config)
        device_config = self.device_config
        parallel_config = self.parallel_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        self.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                   "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            vllm_config=vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        """Bind this process to its XPU and initialize distributed state.

        Selects ``xpu:<local_rank>``, empties the allocator cache, and records
        the device's *total* memory in ``self.init_gpu_memory``. That value is
        the baseline used by ``determine_num_available_blocks``. The method
        then initializes the distributed environment and applies the
        configured random seed.

        Raises:
            RuntimeError: If the configured device is not an XPU or the
                current platform is not XPU.
        """
        if not (self.device_config.device.type == "xpu"
                and current_platform.is_xpu()):
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"xpu:{self.local_rank}")
        torch.xpu.set_device(self.device)
        torch.xpu.empty_cache()
        self.init_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory

        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Apply the configured random seed.
        set_random_seed(self.model_config.seed)

    # XPU-specific override so that the `torch.xpu` memory APIs
    # (`empty_cache`, `synchronize`, allocation stats) are used.
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``, each clamped to be
            non-negative.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()
        # Restart the allocator's high-water mark from the current allocation
        # (presumably the loaded weights). The peak read below then reflects
        # the profiling pass rather than transients from earlier phases.
        torch.xpu.reset_peak_memory_stats()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory. Use the peak, not the current, allocation:
        # memory_allocated() only counts tensors still alive after
        # profile_run returns. Memory allocated and released inside
        # profile_run would be missed, and the KV cache could be over-sized.
        torch.xpu.synchronize()
        used_memory = torch.xpu.max_memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        # init_gpu_memory holds the device's total memory (see init_device),
        # so this difference equals the profiled peak allocation.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        """Intentionally a no-op on XPU."""
        # IPEX does not support graph capture yet.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment.

        If ``torch.distributed`` is already initialized, this only checks
        that its world size matches ``parallel_config.world_size``. Otherwise
        it sets oneCCL-related environment variables and calls
        ``init_distributed_environment`` with the ``ccl`` backend. In both
        cases it then ensures the model-parallel groups exist and issues
        warm-up all-reduces.

        Raises:
            RuntimeError: If an existing process group's world size differs
                from ``parallel_config.world_size``.
            ValueError: If ``torch.distributed`` is not initialized and no
                ``distributed_init_method`` was provided.
        """
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Default oneCCL's ATL transport to `ofi` and LOCAL_WORLD_SIZE to
            # the full world size, keeping any value the user already
            # exported. LOCAL_RANK is always overwritten with this worker's
            # device index.
            # NOTE(review): the previous comment here described defaulting the
            # Level Zero IPC exchange (CCL_ZE_IPC_EXCHANGE) to `sockets`. It
            # said this avoids `drmfd`'s libdrm dependency. This code no
            # longer sets that variable — confirm that is intended.
            os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
            os.environ.setdefault("LOCAL_WORLD_SIZE",
                                  str(parallel_config.world_size))
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        # NOTE(review): the third positional argument is a bool
        # (enable_expert_parallel). In some vLLM versions this parameter of
        # ensure_model_parallel_initialized is `backend: Optional[str]`.
        # Confirm it matches the imported signature.
        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size,
            parallel_config.enable_expert_parallel)
        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        if parallel_config.pipeline_parallel_size > 1:
            # Add pp group init to avoid
            # p2p communication as the first call
            get_pp_group().all_reduce(torch.zeros(1).xpu())