xpu_worker.py 6.93 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from typing import Any
5
6
7
8
9

import torch
import torch.distributed

from vllm.config import VllmConfig
10
from vllm.distributed import get_world_group
11
12
13
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
14
from vllm.profiler.wrapper import TorchProfilerWrapper
15
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from vllm.v1.worker.xpu_model_runner import XPUModelRunner

# Module-level logger, obtained from vLLM's `init_logger` under this module's
# name; used by the worker below for memory-profiling messages.
logger = init_logger(__name__)


class XPUWorker(Worker):
    """A XPU worker class."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the base worker, check the XPU platform, set up profiling.

        All arguments are forwarded unchanged to the base ``Worker``
        constructor. In addition, ``vllm_config.profiler_config`` and
        ``vllm_config.instance_id`` are read here to configure the optional
        torch profiler.

        Args:
            vllm_config: The engine configuration.
            local_rank: Forwarded to the base constructor.
            rank: Forwarded to the base constructor.
            distributed_init_method: Forwarded to the base constructor.
            is_driver_worker: Forwarded to the base constructor.

        Raises:
            AssertionError: If the configured device type is not ``"xpu"`` or
                the current platform is not XPU.
        """
        super().__init__(
            vllm_config,
            local_rank,
            rank,
            distributed_init_method,
            is_driver_worker,
        )

        # This worker is only meaningful on an XPU device and platform.
        xpu_device_config = self.device_config
        assert xpu_device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        # Torch profiler. Stays ``None`` unless profiler_config selects "torch".
        self.profiler: Any | None = None
        prof_cfg = vllm_config.profiler_config
        if prof_cfg.profiler != "torch":
            return

        profiler_worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
        self.profiler = TorchProfilerWrapper(
            prof_cfg,
            worker_name=profiler_worker_name,
            local_rank=self.local_rank,
            activities=["CPU", "XPU"],
        )

    # We provide this function because `torch.xpu.mem_get_info()` doesn't
    # return the correct free GPU memory on Intel client GPUs, so we need to
    # calculate/estimate it ourselves.
    def xpu_get_mem_info(self):
        """Return a ``(free, total)`` pair describing XPU memory.

        On data center GPUs the pair is whatever ``torch.xpu.mem_get_info()``
        reports. On other GPUs only the total from that call is used; the free
        amount is estimated as the total minus torch-allocated memory minus a
        fixed 128 MB allowance for non-torch allocations.
        """
        if current_platform.is_data_center_gpu():
            return torch.xpu.mem_get_info()

        total_bytes = torch.xpu.mem_get_info()[1]
        # FIXME: memory_allocated() doesn't count non-torch allocations,
        # and we don't have any API to get it. so we mark it as 128MB.
        torch_used_bytes = torch.xpu.memory_allocated()
        reserved_non_torch_bytes = 128 * 1024 * 1024
        estimated_free_bytes = total_bytes - torch_used_bytes - reserved_non_torch_bytes
        return estimated_free_bytes, total_bytes

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profile peak memory usage to size the memory left for the KV cache.

        The method records the current memory state, runs a dummy forward
        pass through ``self.model_runner.profile_run()``, and then takes the
        peak torch allocation plus any non-torch allocations still present
        after emptying the cache. That peak is subtracted from the
        ``gpu_memory_utilization`` fraction of total device memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            The memory budget remaining for the KV cache, truncated to ``int``.

        Raises:
            AssertionError: If free memory after the profiling run is not
                below ``self.init_gpu_memory``.
        """
        # Start from a clean slate so the peak statistic reflects only the
        # profiling run below.
        torch.xpu.empty_cache()
        torch.xpu.reset_peak_memory_stats()

        # Use the XPU-aware helper (not raw mem_get_info) so the reported free
        # memory is consistent with the rest of this method on client GPUs.
        free_gpu_memory, total_gpu_memory = self.xpu_get_mem_info()
        current_allocated_bytes = torch.xpu.memory_allocated()
        msg = (
            "Before memory profiling run, "
            f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
            f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
        )
        logger.info(msg)

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = self.xpu_get_mem_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance."
        )

        # Get the peak memory allocation recorded by torch.
        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]

        torch.xpu.empty_cache()
        torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
        # Query once so that free and total come from the same snapshot.
        free_after_cleanup, total_after_cleanup = self.xpu_get_mem_info()
        total_allocated_bytes = total_after_cleanup - free_after_cleanup

        # Anything in use beyond what torch accounts for is treated as
        # non-torch overhead and added to the peak.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
        )

        msg = (
            "After memory profiling run, "
            f"peak memory usage is {peak_memory / 1024**2:.2f} MB, "
            f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
            f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
        )
        logger.info(msg)

        return int(available_kv_cache_memory)

    def init_device(self):
        """Bind the worker to its XPU, join the process group, build the runner.

        Steps, in order: validate the configured device, select
        ``xpu:<local_rank>``, record the device's total memory in
        ``self.init_gpu_memory``, export the oneCCL-related environment
        variables, initialize the distributed environment, run a warm-up
        all-reduce, seed the RNG, and construct the ``XPUModelRunner``.

        Raises:
            RuntimeError: If the configured device is not an XPU
                ``torch.device`` or the current platform is not XPU.
        """
        configured_device = self.device_config.device
        if (
            not isinstance(configured_device, torch.device)
            or configured_device.type != "xpu"
            or not current_platform.is_xpu()
        ):
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"xpu:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)
        torch.xpu.empty_cache()
        device_props = torch.xpu.get_device_properties(self.local_rank)
        self.init_gpu_memory = device_props.total_memory

        # Keep values that are already exported; otherwise fall back to the
        # defaults below. Each variable is (re)written either way.
        env_fallbacks = {
            "CCL_ATL_TRANSPORT": "ofi",
            "LOCAL_WORLD_SIZE": str(self.parallel_config.world_size),
        }
        for env_name, fallback in env_fallbacks.items():
            os.environ[env_name] = os.getenv(env_name, fallback)
        os.environ["LOCAL_RANK"] = str(self.local_rank)

        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # global all_reduce needed for overall oneccl warm up
        warmup_tensor = torch.zeros(1).xpu()
        torch.distributed.all_reduce(warmup_tensor, group=get_world_group().device_group)

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device
        )