cumem.py 11.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
# cumem-based PyTorch pluggable allocator used to implement sleep mode.
# Other approaches were tried but failed:
# - the cuda-python package bindings
# - a custom ctypes wrapper around the libcuda driver
# Both failed because of a CUDA context mismatch: for reasons that are still
# unclear, they ended up operating in a different CUDA context.
# The only approach that worked is calling the CUDA driver API from C.
import dataclasses
12
import gc
13
import os
14
from collections.abc import Callable, Iterator
15
from contextlib import contextmanager
16
from typing import Any
17
18
19

import torch

20
from vllm.logger import init_logger
21
from vllm.utils.platform_utils import is_pin_memory_available
22
from vllm.utils.system_utils import find_loaded_library
23

24
25
logger = init_logger(__name__)

# Fallback values for platforms without the cumem allocator extension; they
# are replaced below when the extension and CudaRTLibrary import cleanly.
cumem_available = False
libcudart: Any = None
try:
    from vllm.cumem_allocator import (
        init_module,
        python_create_and_map,
        python_unmap_and_release,
    )
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    lib_name = find_loaded_library("cumem_allocator")
    libcudart = CudaRTLibrary()
    cumem_available = True
except ModuleNotFoundError:
    # only cuda and rocm platforms support cumem allocator
    init_module = python_create_and_map = python_unmap_and_release = None
    lib_name = None

# Handle tuple exchanged with the C extension, in this order:
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = tuple[int, int, int, int]


@dataclasses.dataclass
class AllocationData:
    """Bookkeeping record for one allocation made through the cumem pool.

    Instances are keyed by device pointer in
    ``CuMemAllocator.pointer_to_data``.
    """

    # Handle tuple as received by the malloc callback
    # (py_device, py_alignedSize, py_d_mem, py_p_memHandle).
    handle: HandleType
    # Allocator tag that was current when the allocation was made; `sleep`
    # uses it to decide between offloading and discarding.
    tag: str
    # CPU copy of the device bytes taken by `sleep`; None when no backup
    # exists.
    cpu_backup_tensor: torch.Tensor | None = dataclasses.field(default=None)
66


def create_and_map(allocation_handle: HandleType) -> None:
    """Re-create and map device memory for ``allocation_handle``.

    The handle fields are forwarded positionally, in ``HandleType`` order, to
    the C extension's ``python_create_and_map``.
    """
    python_create_and_map(*allocation_handle)


def unmap_and_release(allocation_handle: HandleType) -> None:
    """Unmap and release the device memory described by ``allocation_handle``.

    The handle fields are forwarded positionally, in ``HandleType`` order, to
    the C extension's ``python_unmap_and_release``.
    """
    python_unmap_and_release(*allocation_handle)


def get_pluggable_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> torch.cuda.memory.CUDAPluggableAllocator:
    """Build a PyTorch pluggable allocator backed by the cumem extension.

    The two Python callbacks are registered with the extension through
    ``init_module`` first. The allocator is then created from the
    ``my_malloc`` / ``my_free`` symbols of the library ``lib_name``, which was
    looked up at import time via ``find_loaded_library("cumem_allocator")``.

    :param python_malloc_fn: callback taking an allocation handle.
    :param python_free_func: callback taking a device pointer and returning
        the matching allocation handle.
    :return: the newly created ``CUDAPluggableAllocator``.
    """
    init_module(python_malloc_fn, python_free_func)
    symbol_names = ("my_malloc", "my_free")
    return torch.cuda.memory.CUDAPluggableAllocator(lib_name, *symbol_names)


@contextmanager
def use_memory_pool_with_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> Iterator[
    tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator]
]:
    """Route PyTorch CUDA allocations through a fresh cumem-backed pool.

    A pluggable allocator is built from the two callbacks, wrapped in a
    ``MemPool``, and that pool is active for the duration of the ``with``
    block.

    :yields: the ``(mem_pool, allocator)`` pair. ``CuMemAllocator`` stores
        this pair to keep both objects referenced.
    """
    allocator = get_pluggable_allocator(python_malloc_fn, python_free_func)
    pool = torch.cuda.memory.MemPool(allocator._allocator)
    with torch.cuda.memory.use_mem_pool(pool):
        yield pool, allocator


class CuMemAllocator:
    """
    A singleton class that manages a memory pool for CUDA tensors.
    The memory in this pool can be offloaded or discarded when the
    allocator sleeps.

    Inside the `use_memory_pool(tag)` context, all tensors created will
    be allocated in the memory pool, and have the same tag as the
    tag passed to the context.

    When we call `sleep`, all tensors with the specified tag will be
    offloaded to CPU memory, and the rest of the tensors will be discarded.
    When we call `wake_up`, all tensors that were previously offloaded
    will be loaded back to GPU memory, and the rest of the tensors will
    have empty memory.

    Why does it need to be a singleton?
    When allocated tensors are garbage collected, PyTorch will call
    the free callback, which will call the `python_free_callback` method.
    The C-extension uses a global variable to store the function of an
    instance of this class. If we create multiple instances of this class,
    the global variable will be overwritten and the free callback will
    not work as expected.
    """

    instance: "CuMemAllocator | None" = None
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        """
        CuMemAllocator is a singleton class.
        We cannot call the constructor directly.
        Call this method to get the instance.
        """
        assert cumem_available, "cumem allocator is not available"
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

    def __init__(self):
        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
        assert "expandable_segments:True" not in conf, (
            "Expandable segments are not compatible with memory pool. "
            "Please track https://github.com/pytorch/pytorch/issues/147851 "
            "for the latest updates."
        )

        # Device pointer -> bookkeeping for every allocation currently tracked.
        self.pointer_to_data: dict[int, AllocationData] = {}
        # Tag attached to allocations made through the malloc callback.
        self.current_tag: str = CuMemAllocator.default_tag
        # Per-tag (MemPool, CUDAPluggableAllocator) pairs, kept referenced.
        self.allocator_and_pools: dict[str, Any] = {}

        # Creating strong references to the two callbacks here to prevent
        # these ephemeral bound-method objects being garbage collected.
        # See discussions in https://github.com/vllm-project/vllm/pull/22724
        self.python_malloc_callback = self._python_malloc_callback
        self.python_free_callback = self._python_free_callback

    def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
        """
        Internal method to store the allocation data
        when memory is allocated in the memory pool.

        The allocation is recorded under its device pointer
        (``allocation_handle[2]``) with the current tag.
        """
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
            allocation_handle, self.current_tag
        )
        logger.debug(
            "Allocated %s bytes for %s with address %s from cumem allocator",
            allocation_handle[1],
            self.current_tag,
            py_d_mem,
        )

    def _python_free_callback(self, ptr: int) -> HandleType:
        """
        Internal method to look up the allocation data
        when memory is freed in the memory pool.

        Removes ``ptr`` from the bookkeeping (``KeyError`` if it is not
        tracked), drops any CPU backup, and returns the allocation's handle.
        """
        data = self.pointer_to_data.pop(ptr)
        # Drop the CPU backup, if any, so its memory can be reclaimed.
        data.cpu_backup_tensor = None
        logger.debug(
            "Freed %s bytes for %s with address %s from cumem allocator",
            data.handle[1],
            data.tag,
            ptr,
        )
        return data.handle

    def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
        """
        Put the allocator in sleep mode.
        All data in the memory allocation with the specified tag will be
        offloaded to CPU memory, and others will be discarded.

        Every tracked allocation is unmapped and released. Offloaded ones are
        first copied into a CPU uint8 tensor (pinned when pinned memory is
        available).

        :param offload_tags: The tags of the memory allocation that will be
            offloaded. The rest of the memory allocation will be discarded.
            A single string is treated as a one-element tuple; None means
            the default tag.
        """
        if offload_tags is None:
            # by default, allocated tensors are offloaded
            # when the allocator sleeps
            offload_tags = (CuMemAllocator.default_tag,)
        elif isinstance(offload_tags, str):
            offload_tags = (offload_tags,)

        assert isinstance(offload_tags, tuple)

        total_bytes = 0
        backup_bytes = 0
        # Loop-invariant: query pinned-memory support once, not per allocation.
        pin_memory = is_pin_memory_available()

        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            size_in_bytes = handle[1]
            total_bytes += size_in_bytes
            if data.tag in offload_tags:
                backup_bytes += size_in_bytes
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
                    device="cpu",
                    pin_memory=pin_memory,
                )
                cpu_ptr = cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
                data.cpu_backup_tensor = cpu_backup_tensor
            unmap_and_release(handle)

        logger.info(
            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
            "directly.",
            total_bytes / 1024**3,
            backup_bytes / 1024**3,
            (total_bytes - backup_bytes) / 1024**3,
        )

        gc.collect()
        torch.cuda.empty_cache()

    def wake_up(self, tags: list[str] | None = None) -> None:
        """
        Wake up the allocator from sleep mode.
        All data that was previously offloaded will be loaded back to GPU
        memory, and the rest of the data will have empty memory.

        :param tags: The tags of the memory allocation that will be loaded
            back to GPU memory. If None, all memory allocation will be loaded
            back to GPU memory.
        """
        for ptr, data in self.pointer_to_data.items():
            if tags is not None and data.tag not in tags:
                continue
            create_and_map(data.handle)
            cpu_backup_tensor = data.cpu_backup_tensor
            if cpu_backup_tensor is None:
                # Discarded during sleep: the remapped memory stays
                # uninitialized.
                continue
            size_in_bytes = (
                cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
            )
            cpu_ptr = cpu_backup_tensor.data_ptr()
            libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
            data.cpu_backup_tensor = None

    @contextmanager
    def use_memory_pool(self, tag: str | None = None):
        """
        A context manager to use the memory pool.
        All memory allocation created inside the context will be allocated
        in the memory pool, and has the specified tag.

        The previous tag is restored on exit, including when the body
        raises.

        :param tag: The tag of the memory allocation. If None, the default tag
            will be used.
        """
        if tag is None:
            tag = CuMemAllocator.default_tag

        assert isinstance(tag, str)

        old_tag = self.current_tag
        self.current_tag = tag
        try:
            with use_memory_pool_with_allocator(
                self.python_malloc_callback, self.python_free_callback
            ) as data:
                # start to hit another PyTorch bug in PyTorch 2.6,
                # possibly because of gc-related issue w.r.t. the allocator and
                # the memory pool.
                # to avoid the issue, we keep a reference of the data.
                # see https://github.com/pytorch/pytorch/issues/146431 .
                self.allocator_and_pools[tag] = data
                yield
                # PyTorch's bug, calling torch.cuda.empty_cache() will error
                # when using pluggable allocator, see
                # https://github.com/pytorch/pytorch/issues/145168 .
                # if we have some memory allocated and then freed,
                # the memory will not be released, e.g. in online quantization,
                # where the model is created in higher precision, and then
                # quantized in lower precision.
                # Find all unused allocations and manually release them.
                # TODO: we should expose `empty_cache` method in the memory pool.
                # TODO: ask for help from PyTorch team to expose this method.
                allocations = data[0].snapshot()
                for allocation in allocations:
                    if allocation["allocated_size"] == 0:
                        handle = self._python_free_callback(allocation["address"])
                        unmap_and_release(handle)
        finally:
            # Restore the previous tag even if the body (or pool setup)
            # raised; otherwise later allocations would keep this tag.
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        """
        Get the total number of bytes allocated in the memory pool.

        This is the sum of the aligned sizes (``handle[1]``) of all tracked
        allocations.
        """
        return sum(data.handle[1] for data in self.pointer_to_data.values())