cumem.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
# cumem-based PyTorch pluggable allocator used to implement sleep mode.
# Other approaches were tried but failed:
# - the cuda-python package binding
# - a custom ctypes wrapper around the libcuda driver
# Both failed because of a CUDA context mismatch: for reasons that are not
# yet understood, their allocations appear to be created in a different
# context. The only approach that worked is calling the CUDA driver API in C.
import dataclasses
import gc
import os
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any

import torch

from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.system_utils import find_loaded_library
23

24
25
logger = init_logger(__name__)

# The cumem allocator is a C extension that only the CUDA and ROCm platforms
# provide. Import it defensively: when it is missing, every symbol the rest of
# this module relies on is bound to None and `cumem_available` is False.
try:
    from vllm.cumem_allocator import (
        init_module,
        python_create_and_map,
        python_unmap_and_release,
    )
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    lib_name = find_loaded_library("cumem_allocator")
    libcudart = CudaRTLibrary()
except ModuleNotFoundError:
    # only cuda and rocm platforms support cumem allocator
    init_module = python_create_and_map = python_unmap_and_release = None
    CudaRTLibrary = lib_name = libcudart = None
    cumem_available = False
else:
    cumem_available = True

# Layout of the handle tuple exchanged with the C extension:
# (py_device, py_alignedSize, py_d_mem, py_p_memHandle)
HandleType = tuple[int, int, int, int]
50
51
52
53
54
55


@dataclasses.dataclass
class AllocationData:
    """Bookkeeping record for one allocation made through the cumem pool.

    Holds the raw handle tuple for the allocation, the pool tag that was
    active when it was made, and (while the allocator is asleep and the tag
    was offloaded) a host-side copy of its contents.
    """

    # (py_device, py_alignedSize, py_d_mem, py_p_memHandle)
    handle: HandleType
    # Tag that was current in `use_memory_pool` when this was allocated.
    tag: str
    # uint8 CPU copy of the device bytes; None unless offloaded during sleep.
    cpu_backup_tensor: torch.Tensor | None = dataclasses.field(default=None)
57
58
59
60
61
62
63
64
65
66
67


def create_and_map(allocation_handle: HandleType) -> None:
    """Forward ``allocation_handle`` to the C extension's create-and-map call.

    The handle's fields (py_device, py_alignedSize, py_d_mem, py_p_memHandle)
    are passed positionally, in tuple order.
    """
    handle_fields = tuple(allocation_handle)
    python_create_and_map(*handle_fields)


def unmap_and_release(allocation_handle: HandleType) -> None:
    """Forward ``allocation_handle`` to the C extension's unmap-and-release call.

    The handle's fields (py_device, py_alignedSize, py_d_mem, py_p_memHandle)
    are passed positionally, in tuple order.
    """
    handle_fields = tuple(allocation_handle)
    python_unmap_and_release(*handle_fields)


def get_pluggable_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> torch.cuda.memory.CUDAPluggableAllocator:
    """Create a PyTorch pluggable allocator backed by the cumem C extension.

    The two Python callbacks are registered with the extension through
    ``init_module``. The allocator is then built from the extension's
    ``my_malloc`` / ``my_free`` entry points in the library located by
    ``find_loaded_library`` (``lib_name``).

    :param python_malloc_fn: Callback that receives the allocation handle
        (see ``HandleType``) of a new allocation.
    :param python_free_func: Callback that receives a freed device pointer
        and returns the handle that was recorded for it.
    :return: A ``CUDAPluggableAllocator`` wrapping the extension.
    """
    # NOTE(review): per the CuMemAllocator docstring, the extension keeps
    # these callbacks in a global, so calling this again replaces them.
    init_module(python_malloc_fn, python_free_func)
    return torch.cuda.memory.CUDAPluggableAllocator(lib_name, "my_malloc", "my_free")


@contextmanager
def use_memory_pool_with_allocator(
    python_malloc_fn: Callable[[HandleType], None],
    python_free_func: Callable[[int], HandleType],
) -> Generator[
    tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator],
    None,
    None,
]:
    """Route CUDA allocations in the ``with`` body into a cumem-backed pool.

    Builds a pluggable allocator around the given callbacks, wraps it in a
    ``torch.cuda.memory.MemPool`` and activates that pool via
    ``torch.cuda.memory.use_mem_pool`` for the duration of the context.

    :param python_malloc_fn: Callback that receives the allocation handle
        (see ``HandleType``) of a new allocation.
    :param python_free_func: Callback that receives a freed device pointer
        and returns the handle that was recorded for it.
    :yields: ``(mem_pool, allocator)`` for the active pool.
    """
    new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
    # MemPool takes the underlying C++ allocator object, not the Python wrapper.
    mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
    with torch.cuda.memory.use_mem_pool(mem_pool):
        yield mem_pool, new_alloc
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110


class CuMemAllocator:
    """
    A singleton class that manages a memory pool for CUDA tensors.
    The memory in this pool can be offloaded or discarded when the
    allocator sleeps.

    Inside the `use_memory_pool(tag)` context, all tensors created will
    be allocated in the memory pool and carry the tag passed to the
    context.

    When we call `sleep`, all tensors with the specified tags will be
    offloaded to CPU memory, and the rest of the tensors will be discarded.
    When we call `wake_up`, all tensors that were previously offloaded
    will be loaded back to GPU memory, and the rest of the tensors will
    get freshly mapped memory whose previous contents are not restored.

    Why does it need to be a singleton?
    When allocated tensors are garbage collected, PyTorch will call
    the free callback, which will call the `python_free_callback` method.
    The C-extension uses a global variable to store the function of an
    instance of this class. If we create multiple instances of this class,
    the global variable will be overwritten and the free callback will
    not work as expected.
    """

    # The lazily created singleton; obtain it through `get_instance()`.
    instance: "CuMemAllocator | None" = None
    # Tag used when `use_memory_pool` / `sleep` are called without one.
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        """
        CuMemAllocator is a singleton class.
        Do not call the constructor directly;
        call this method to get the instance.

        Raises AssertionError if the cumem C extension could not be imported.
        """
        assert cumem_available, "cumem allocator is not available"
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

    def __init__(self):
        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
        assert "expandable_segments:True" not in conf, (
            "Expandable segments are not compatible with memory pool. "
            "Please track https://github.com/pytorch/pytorch/issues/147851 "
            "for the latest updates."
        )

        # Device pointer -> bookkeeping for every live allocation in the pool.
        self.pointer_to_data: dict[int, AllocationData] = {}
        # Tag attached to allocations made while inside `use_memory_pool`.
        self.current_tag: str = CuMemAllocator.default_tag
        # Tag -> (MemPool, allocator), kept alive deliberately; see
        # `use_memory_pool`.
        self.allocator_and_pools: dict[str, Any] = {}
        # Creating strong references to the two callbacks here to prevent
        # these ephemeral bound-method objects being garbage collected.
        # See discussions in https://github.com/vllm-project/vllm/pull/22724
        self.python_malloc_callback = self._python_malloc_callback
        self.python_free_callback = self._python_free_callback

    def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
        """
        Internal method to store the allocation data
        when memory is allocated in the memory pool."""
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
            allocation_handle, self.current_tag
        )
        logger.debug(
            "Allocated %s bytes for %s with address %s from cumem allocator",
            allocation_handle[1],
            self.current_tag,
            py_d_mem,
        )

    def _python_free_callback(self, ptr: int) -> HandleType:
        """
        Internal method to look up the allocation data
        when memory is freed in the memory pool.

        Removes the bookkeeping entry for `ptr` and returns its handle.
        Raises KeyError if `ptr` is not tracked."""
        data = self.pointer_to_data.pop(ptr)
        # Drop any host-side backup right away; the allocation is gone.
        data.cpu_backup_tensor = None
        logger.debug(
            "Freed %s bytes for %s with address %s from cumem allocator",
            data.handle[1],
            data.tag,
            ptr,
        )
        return data.handle

    def sleep(
        self, offload_tags: tuple[str, ...] | list[str] | str | None = None
    ) -> None:
        """
        Put the allocator in sleep mode.
        All data in the memory allocation with the specified tag will be
        offloaded to CPU memory, and others will be discarded.

        :param offload_tags: The tags of the memory allocation that will be
            offloaded. The rest of the memory allocation will be discarded.
            A single tag may be given as a plain string; any iterable of
            tags (tuple, list, ...) is accepted. Defaults to the default tag.
        """
        if offload_tags is None:
            # by default, allocated tensors are offloaded
            # when the allocator sleeps
            offload_tags = (CuMemAllocator.default_tag,)
        elif isinstance(offload_tags, str):
            # Wrap a bare tag so `in` tests membership, not substrings.
            offload_tags = (offload_tags,)
        else:
            offload_tags = tuple(offload_tags)

        # Loop-invariant: evaluate once rather than per allocation.
        pin_memory = is_pin_memory_available()
        total_bytes = 0
        backup_bytes = 0

        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            size_in_bytes = handle[1]
            total_bytes += size_in_bytes
            if data.tag in offload_tags:
                backup_bytes += size_in_bytes
                # Copy the device bytes to a host buffer before the physical
                # memory is released below.
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
                    device="cpu",
                    pin_memory=pin_memory,
                )
                libcudart.cudaMemcpy(cpu_backup_tensor.data_ptr(), ptr, size_in_bytes)
                data.cpu_backup_tensor = cpu_backup_tensor
            # The entry stays in `pointer_to_data` so `wake_up` can remap it.
            unmap_and_release(handle)

        logger.info(
            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
            "directly.",
            total_bytes / 1024**3,
            backup_bytes / 1024**3,
            (total_bytes - backup_bytes) / 1024**3,
        )

        gc.collect()
        torch.cuda.empty_cache()

    def wake_up(self, tags: list[str] | str | None = None) -> None:
        """
        Wake up the allocator from sleep mode.
        All data that was previously offloaded will be loaded back to GPU
        memory, and the rest of the data will get freshly mapped memory
        whose previous contents are not restored.

        :param tags: The tags of the memory allocation that will be loaded
            back to GPU memory. If None, all memory allocation will be loaded
            back to GPU memory. A single tag may be given as a plain string.
        """
        if isinstance(tags, str):
            # Otherwise `data.tag in tags` would perform a substring match.
            tags = [tags]
        for ptr, data in self.pointer_to_data.items():
            if tags is not None and data.tag not in tags:
                continue
            create_and_map(data.handle)
            cpu_backup_tensor = data.cpu_backup_tensor
            if cpu_backup_tensor is None:
                continue
            size_in_bytes = (
                cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
            )
            libcudart.cudaMemcpy(ptr, cpu_backup_tensor.data_ptr(), size_in_bytes)
            # The device holds the data again; release the host copy.
            data.cpu_backup_tensor = None

    @contextmanager
    def use_memory_pool(self, tag: str | None = None):
        """
        A context manager to use the memory pool.
        All memory allocation created inside the context will be allocated
        in the memory pool, and has the specified tag.

        The previous tag is restored on exit, including when the body of the
        context raises.

        :param tag: The tag of the memory allocation. If None, the default tag
            will be used.
        """
        if tag is None:
            tag = CuMemAllocator.default_tag

        assert isinstance(tag, str)

        old_tag = self.current_tag
        self.current_tag = tag
        try:
            with use_memory_pool_with_allocator(
                self.python_malloc_callback, self.python_free_callback
            ) as data:
                # start to hit another PyTorch bug in PyTorch 2.6,
                # possibly because of gc-related issue w.r.t. the allocator and
                # the memory pool.
                # to avoid the issue, we keep a reference of the data.
                # see https://github.com/pytorch/pytorch/issues/146431 .
                self.allocator_and_pools[tag] = data
                yield
                # PyTorch's bug, calling torch.cuda.empty_cache() will error
                # when using pluggable allocator, see
                # https://github.com/pytorch/pytorch/issues/145168 .
                # if we have some memory allocated and then freed,
                # the memory will not be released, e.g. in online quantization,
                # where the model is created in higher precision, and then
                # quantized in lower precision.
                # Find all unused allocations and manually release them.
                # TODO: we should expose `empty_cache` method in the memory pool.
                # TODO: ask for help from PyTorch team to expose this method.
                allocations = data[0].snapshot()
                for allocation in allocations:
                    if allocation["allocated_size"] == 0:
                        handle = self._python_free_callback(allocation["address"])
                        unmap_and_release(handle)
        finally:
            # Without this, an exception would leave later allocations
            # attributed to `tag`.
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        """
        Get the total number of bytes allocated in the memory pool
        (the sum of the aligned sizes of all tracked allocations).
        """
        return sum(data.handle[1] for data in self.pointer_to_data.values())