# cumem.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
# cumem-based PyTorch pluggable allocator used to implement sleep mode.
# Other approaches were tried but failed:
# - the cuda-python package bindings
# - a custom ctypes wrapper around the libcuda driver API
# Both failed because of a CUDA context mismatch: for reasons not yet
# understood, they operated in a different CUDA context.
# The only approach that worked is calling the CUDA driver API from C.
import dataclasses
import gc
import os
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any

import torch

from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available

23
24
# Module logger; used for the allocate/free debug messages and the sleep summary.
logger = init_logger(__name__)

25

26
def find_loaded_library(lib_name) -> str | None:
27
28
29
30
31
    """
    According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which includes the
    shared libraries loaded by the process. We can use this file to find the path of the
    a loaded library.
32
    """  # noqa
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    found_line = None
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name in line:
                found_line = line
                break
    if found_line is None:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    start = found_line.index("/")
    path = found_line[start:].strip()
    filename = path.split("/")[-1]
47
    assert filename.rpartition(".so")[0].startswith(lib_name), (
48
        f"Unexpected filename: {filename} for library {lib_name}"
49
    )
50
51
52
53
54
    return path


cumem_available = False
try:
    from vllm.cumem_allocator import (
        init_module,
        python_create_and_map,
        python_unmap_and_release,
    )
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    # Path of the already-loaded C extension; get_pluggable_allocator passes
    # it to torch's CUDAPluggableAllocator to resolve `my_malloc`/`my_free`.
    lib_name = find_loaded_library("cumem_allocator")
    libcudart = CudaRTLibrary()
    cumem_available = True
except ModuleNotFoundError:
    # The cumem allocator extension is not available on every platform
    # (e.g. ROCm). Bind placeholders so this module still imports, and
    # leave cumem_available False so CuMemAllocator.get_instance() fails
    # loudly instead.
    init_module = None
    python_create_and_map = None
    python_unmap_and_release = None
    CudaRTLibrary = None
    lib_name = None
    libcudart = None

# An allocation handle as exchanged with the cumem C extension, in order:
# (py_device, py_alignedSize, py_d_mem, py_p_memHandle)
HandleType = tuple[int, int, int, int]
76
77
78
79
80
81


@dataclasses.dataclass
class AllocationData:
    handle: HandleType
    tag: str
82
    cpu_backup_tensor: torch.Tensor | None = None
83
84
85
86
87
88
89
90
91
92
93


def create_and_map(allocation_handle: HandleType) -> None:
    """Forward the four handle fields to the C extension's
    `python_create_and_map` (presumably creates the physical memory and maps
    it at the handle's device address — see the C extension)."""
    python_create_and_map(*allocation_handle)


def unmap_and_release(allocation_handle: HandleType) -> None:
    """Forward the four handle fields to the C extension's
    `python_unmap_and_release` (presumably unmaps the device address and
    releases the physical memory — see the C extension)."""
    python_unmap_and_release(*allocation_handle)


def get_pluggable_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> torch.cuda.memory.CUDAPluggableAllocator:
    """
    Register the Python callbacks with the cumem C extension and wrap the
    extension's `my_malloc` / `my_free` in a PyTorch pluggable allocator.

    :param python_malloc_fn: receives the allocation handle
        (py_device, py_alignedSize, py_d_mem, py_p_memHandle) when memory is
        allocated in the pool.
    :param python_free_func: receives the device pointer being freed and
        must return that allocation's handle.
    :return: a `CUDAPluggableAllocator` backed by the loaded extension.
    """
    # The extension keeps these callbacks in global state (see CuMemAllocator).
    init_module(python_malloc_fn, python_free_func)
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_name, "my_malloc", "my_free"
    )
    return new_alloc


@contextmanager
def use_memory_pool_with_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> Iterator[
    tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator]
]:
    """
    Route PyTorch CUDA allocations through a fresh cumem-backed memory pool
    for the duration of the context.

    :param python_malloc_fn: receives the allocation handle when memory is
        allocated in the pool.
    :param python_free_func: receives the freed device pointer and returns
        that allocation's handle.
    :yield: ``(mem_pool, allocator)``. CuMemAllocator.use_memory_pool keeps a
        reference to this pair to work around a PyTorch gc-related issue.
    """
    new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
    mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
    with torch.cuda.memory.use_mem_pool(mem_pool):
        yield mem_pool, new_alloc
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136


class CuMemAllocator:
    """
    A singleton class that manages a memory pool for CUDA tensors.
    The memory in this pool can be offloaded or discarded when the
    allocator sleeps.

    Inside the `use_memory_pool(tag)` context, all tensors created will
    be allocated in the memory pool, and has the same tag as the
    tag passed to the context.

    When we call `sleep`, all tensors with the specified tag will be
    offloaded to CPU memory, and the rest of the tensors will be discarded.
    When we call `wake_up`, all tensors that are previously offloaded
    will be loaded back to GPU memory, and the rest of the tensors will
    have empty memory.

    Why it needs to be a singleton?
    When allocated tensors are garbage collected, PyTorch will call
    the free callback, which will call the `python_free_callback` method.
    The C-extension uses a global variable to store the function of an
    instance of this class. If we create multiple instances of this class,
    the global variable will be overwritten and the free callback will
    not work as expected.
    """
137

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    instance: "CuMemAllocator" = None
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        """
        CuMemAllocator is a singleton class.
        We cannot call the constructor directly.
        Call this method to get the instance.
        """
        assert cumem_available, "cumem allocator is not available"
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

    def __init__(self):
154
        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
155
156
        assert "expandable_segments:True" not in conf, (
            "Expandable segments are not compatible with memory pool. "
157
            "Please track https://github.com/pytorch/pytorch/issues/147851 "
158
159
            "for the latest updates."
        )
160

161
        self.pointer_to_data: dict[int, AllocationData] = {}
162
        self.current_tag: str = CuMemAllocator.default_tag
163
        self.allocator_and_pools: dict[str, Any] = {}
164
165
166
167
168
        # Creating strong references to the two callbacks here to prevent
        # these ephemeral bound-method objects being garbage collected.
        # See discussions in https://github.com/vllm-project/vllm/pull/22724
        self.python_malloc_callback = self._python_malloc_callback
        self.python_free_callback = self._python_free_callback
169

170
    def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
171
172
173
174
175
        """
        Internal method to store the allocation data
        when memory is allocated in the memory pool."""
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
176
177
            allocation_handle, self.current_tag
        )
178
179
        logger.debug(
            "Allocated %s bytes for %s with address %s from cumem allocator",
180
181
182
183
            allocation_handle[1],
            self.current_tag,
            py_d_mem,
        )
184
185
        return

186
    def _python_free_callback(self, ptr: int) -> HandleType:
187
188
189
190
191
192
        """
        Internal method to look up the allocation data
        when memory is freed in the memory pool."""
        data = self.pointer_to_data.pop(ptr)
        if data.cpu_backup_tensor is not None:
            data.cpu_backup_tensor = None
193
194
        logger.debug(
            "Freed %s bytes for %s with address %s from cumem allocator",
195
196
197
198
            data.handle[1],
            data.tag,
            ptr,
        )
199
200
        return data.handle

201
    def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
202
203
        """
        Put the allocator in sleep mode.
204
        All data in the memory allocation with the specified tag will be
205
206
207
208
209
210
211
212
        offloaded to CPU memory, and others will be discarded.

        :param offload_tags: The tags of the memory allocation that will be
            offloaded. The rest of the memory allocation will be discarded.
        """
        if offload_tags is None:
            # by default, allocated tensors are offloaded
            # when the allocator sleeps
213
            offload_tags = (CuMemAllocator.default_tag,)
214
        elif isinstance(offload_tags, str):
215
            offload_tags = (offload_tags,)
216
217
218

        assert isinstance(offload_tags, tuple)

219
220
221
        total_bytes = 0
        backup_bytes = 0

222
223
        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
224
            total_bytes += handle[1]
225
            if data.tag in offload_tags:
226
                backup_bytes += handle[1]
227
228
229
230
                size_in_bytes = handle[1]
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
231
232
233
                    device="cpu",
                    pin_memory=is_pin_memory_available(),
                )
234
235
236
237
238
                cpu_ptr = cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
                data.cpu_backup_tensor = cpu_backup_tensor
            unmap_and_release(handle)

239
240
241
        logger.info(
            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
242
243
244
245
246
            "directly.",
            total_bytes / 1024**3,
            backup_bytes / 1024**3,
            (total_bytes - backup_bytes) / 1024**3,
        )
247

248
249
250
        gc.collect()
        torch.cuda.empty_cache()

251
    def wake_up(self, tags: list[str] | None = None) -> None:
252
253
        """
        Wake up the allocator from sleep mode.
254
        All data that is previously offloaded will be loaded back to GPU
255
        memory, and the rest of the data will have empty memory.
256

257
258
259
260
        :param tags: The tags of the memory allocation that will be loaded
            back to GPU memory. If None, all memory allocation will be loaded
            back to GPU memory.
        """
261
        for ptr, data in self.pointer_to_data.items():
262
263
264
265
266
267
            if tags is None or data.tag in tags:
                handle = data.handle
                create_and_map(handle)
                if data.cpu_backup_tensor is not None:
                    cpu_backup_tensor = data.cpu_backup_tensor
                    if cpu_backup_tensor is not None:
268
269
270
                        size_in_bytes = (
                            cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
                        )
271
272
273
                        cpu_ptr = cpu_backup_tensor.data_ptr()
                        libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
                        data.cpu_backup_tensor = None
274
275

    @contextmanager
276
    def use_memory_pool(self, tag: str | None = None):
277
278
        """
        A context manager to use the memory pool.
279
        All memory allocation created inside the context will be allocated
280
281
282
283
284
285
286
287
288
289
290
291
        in the memory pool, and has the specified tag.

        :param tag: The tag of the memory allocation. If None, the default tag
            will be used.
        """
        if tag is None:
            tag = CuMemAllocator.default_tag

        assert isinstance(tag, str)

        old_tag = self.current_tag
        self.current_tag = tag
292
293
294
        with use_memory_pool_with_allocator(
            self.python_malloc_callback, self.python_free_callback
        ) as data:
295
296
297
298
299
300
            # start to hit another PyTorch bug in PyTorch 2.6,
            # possibly because of gc-related issue w.r.t. the allocator and
            # the memory pool.
            # to avoid the issue, we keep a reference of the data.
            # see https://github.com/pytorch/pytorch/issues/146431 .
            self.allocator_and_pools[tag] = data
301
302
303
304
305
            yield
            # PyTorch's bug, calling torch.cuda.empty_cache() will error
            # when using pluggable allocator, see
            # https://github.com/pytorch/pytorch/issues/145168 .
            # if we have some memory allocated and then freed,
306
307
308
309
310
311
312
313
314
315
316
            # the memory will not be released, e.g. in online quantization,
            # where the model is created in higher precision, and then
            # quantized in lower precision.
            # Find all unused allocations and manually release them.
            # TODO: we should expose `empty_cache` method in the memory pool.
            # TODO: ask for help from PyTorch team to expose this method.
            allocations = data[0].snapshot()
            for allocation in allocations:
                if allocation["allocated_size"] == 0:
                    handle = self._python_free_callback(allocation["address"])
                    unmap_and_release(handle)
317
318
319
320
321
322
323
324
325
326
327
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        """
        Get the total number of bytes allocated in the memory pool.
        """
        sum_bytes: int = 0
        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            sum_bytes += handle[1]
        return sum_bytes