# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_world_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(Worker):
    """A XPU worker class.

    Specializes the base ``Worker`` with XPU-specific profiler setup, memory
    accounting and device / distributed initialization.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Initialize the base worker and, when enabled, the torch profiler.

        All arguments are forwarded unchanged to ``Worker.__init__``.

        Args:
            vllm_config: The vLLM configuration object.
            local_rank: Rank of this worker on the local node; used later as
                the XPU device index.
            rank: Rank of this worker in the distributed group.
            distributed_init_method: Init method handed to the distributed
                environment setup.
            is_driver_worker: Whether this worker is the driver worker.
        """
        super().__init__(
            vllm_config, local_rank, rank, distributed_init_method, is_driver_worker
        )
        device_config = self.device_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                torch_profiler_trace_dir,
            )
            logger.debug(
                "Profiler config: record_shapes=%s,"
                "profile_memory=%s,with_stack=%s,with_flops=%s",
                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                envs.VLLM_TORCH_PROFILER_WITH_STACK,
                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.XPU,
                ],
                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    def xpu_get_mem_info(self) -> tuple[int, int]:
        """Return ``(free_bytes, total_bytes)`` for the current XPU device.

        We provide this function because `torch.xpu.mem_get_info()` doesn't
        return a correct free-memory value on Intel client GPUs, so on those
        devices the free memory is calculated/estimated instead.

        Returns:
            A ``(free_gpu_memory, total_gpu_memory)`` tuple in bytes.
        """
        if current_platform.is_data_center_gpu():
            return torch.xpu.mem_get_info()

        _, total_gpu_memory = torch.xpu.mem_get_info()
        # FIXME: memory_allocated() doesn't count non-torch allocations,
        # and we don't have any API to get it. so we mark it as 128MB.
        used_memory = torch.xpu.memory_allocated()
        non_torch_allocations = 128 * 1024 * 1024
        free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
        return free_gpu_memory, total_gpu_memory

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.

        Returns:
            ``total_gpu_memory * gpu_memory_utilization - peak_memory``,
            truncated to an int (bytes).
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()
        torch.xpu.reset_peak_memory_stats()

        # Go through the XPU-aware helper so the reported free memory is
        # consistent with the measurements taken after the profiling run.
        free_gpu_memory, total_gpu_memory = self.xpu_get_mem_info()
        current_allocated_bytes = torch.xpu.memory_allocated()
        msg = (
            "Before memory profiling run, "
            f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
            f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
        )
        logger.info(msg)

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = self.xpu_get_mem_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        # `init_gpu_memory` is set in `init_device` from the device's
        # `total_memory` property.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance."
        )

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]

        torch.xpu.empty_cache()
        torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
        # Take one snapshot so that free and total come from the same reading.
        free_after_cleanup, total_after_cleanup = self.xpu_get_mem_info()
        total_allocated_bytes = total_after_cleanup - free_after_cleanup

        # Memory in use that torch's allocator does not account for is added
        # on top of the torch-reported peak.
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
        )

        msg = (
            "After memory profiling run, "
            f"peak memory usage is {peak_memory / 1024**2:.2f} MB, "
            f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
            f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
            f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
        )
        logger.info(msg)

        return int(available_kv_cache_memory)

    def init_device(self) -> None:
        """Bind this worker to its XPU and bring up the distributed runtime.

        Selects ``xpu:<local_rank>``, records the device's total memory in
        ``init_gpu_memory``, exports the CCL-related environment variables,
        initializes the distributed environment, performs a warm-up
        all-reduce, seeds the RNG and constructs the model runner.

        Raises:
            RuntimeError: If the configured device is not an XPU or the
                current platform is not XPU.
        """
        if not (self.device_config.device.type == "xpu" and current_platform.is_xpu()):
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        self.device = torch.device(f"xpu:{self.local_rank}")
        current_platform.set_device(self.device)
        current_platform.check_if_supports_dtype(self.model_config.dtype)
        torch.xpu.empty_cache()
        self.init_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank
        ).total_memory

        # Respect values already present in the environment; only fill in
        # defaults for the ones that are missing.
        os.environ.setdefault("CCL_ZE_IPC_EXCHANGE", "pidfd")
        os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
        os.environ.setdefault("LOCAL_WORLD_SIZE", str(self.parallel_config.world_size))
        # LOCAL_RANK is always overwritten with this worker's local rank.
        os.environ["LOCAL_RANK"] = str(self.local_rank)

        init_worker_distributed_environment(
            self.vllm_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
            current_platform.dist_backend,
        )

        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(
            torch.zeros(1).xpu(), group=get_world_group().device_group
        )

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device
        )