"""An XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

11
from vllm.config import VllmConfig
12
13
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
14
from vllm.distributed.parallel_state import get_pp_group
15
16
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
17
from vllm.platforms import current_platform
18
19
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
20
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(LoraNotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on a GPU.
    
    Each worker is associated with a single XPU device. The worker is 
    responsible for maintaining the KV cache and executing the model on the 
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """Record the worker's identity and build its XPU model runner.

        Args:
            vllm_config: Engine configuration; forwarded to
                ``WorkerBase.__init__`` and to ``XPUModelRunner``.
            local_rank: Device index of this worker; ``init_device`` uses it
                to select ``xpu:{local_rank}``.
            rank: Rank of this worker; also written to
                ``self.parallel_config.rank``.
            distributed_init_method: Init method handed to
                ``init_distributed_environment`` when torch.distributed has
                not been initialized yet.
            is_driver_worker: Whether this worker is the driver. A driver
                must sit on a rank that is a multiple of the tensor parallel
                size.
        """
        # WorkerBase.__init__ is called directly instead of going through
        # super(); the config attributes read below (device_config,
        # parallel_config, cache_config) are expected to be set by it.
        WorkerBase.__init__(self, vllm_config=vllm_config)
        device_config = self.device_config
        parallel_config = self.parallel_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        self.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                   "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            vllm_config=vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        # These are annotation-only declarations: no attribute is created
        # until a value is assigned elsewhere.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        """Select this worker's XPU device and set up distributed state.

        Binds the worker to ``xpu:{local_rank}``, records the device's total
        memory in ``self.init_gpu_memory``, initializes the distributed
        environment and applies the configured random seed.

        Raises:
            RuntimeError: If the configured device type is not ``"xpu"`` or
                the current platform is not XPU.
        """
        if not (self.device_config.device.type == "xpu"
                and current_platform.is_xpu()):
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"xpu:{self.local_rank}")
        torch.xpu.set_device(self.device)
        torch.xpu.empty_cache()
        # Note: this stores the device's *total* memory, which
        # determine_num_available_blocks later uses as its baseline.
        self.init_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory

        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Apply the random seed from the model config.
        set_random_seed(self.model_config.seed)

    # keep this method for `empty_cache` and `synchronize` api
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile the model's memory usage to size the KV cache.

        The engine first runs a profiling forward pass with dummy inputs.
        Then it calculates the maximum number of GPU and CPU blocks that can
        be allocated with the remaining memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple; both values are
            clamped to be non-negative.

        Raises:
            AssertionError: If the computed peak memory is not positive.
        """
        # Start from a clean allocator cache so the measurement below
        # reflects the profiling run.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Wait for the profiling work to finish before reading the
        # allocator statistics.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        # init_gpu_memory is set in init_device to the device's total memory,
        # so this difference is the memory allocated after the profiling run.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        # Memory budget = utilization fraction of the total, minus what the
        # model itself needed during profiling.
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        # A budget smaller than the profiled usage yields a negative count;
        # report zero blocks instead.
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        # Release whatever the profiling run left behind.
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX don't support capture graph yet
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment.

        If torch.distributed is already initialized, only its world size is
        validated against ``parallel_config.world_size``. Otherwise the
        oneCCL-related environment variables are given defaults and the
        distributed environment is created with the ``"ccl"`` backend. In
        both cases the model parallel groups are then ensured and a warm-up
        all-reduce is issued.

        Raises:
            RuntimeError: If torch.distributed is already initialized with a
                world size different from ``parallel_config.world_size``.
            ValueError: If torch.distributed is not initialized and no
                ``distributed_init_method`` was provided.
        """

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Provide defaults for the oneCCL-related environment variables
            # without overriding values that are already set in the
            # environment. LOCAL_RANK is always overwritten.
            # NOTE(review): an earlier comment here described selecting
            # "sockets" as the Level Zero IPC exchange backend, but no such
            # variable is set in this block -- confirm that is intentional.
            os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
            os.environ.setdefault("LOCAL_WORLD_SIZE",
                                  str(parallel_config.world_size))
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        if parallel_config.pipeline_parallel_size > 1:
            # Add pp group init to avoid
            # p2p communication as the first call
            get_pp_group().all_reduce(torch.zeros(1).xpu())