mem_utils.py 10.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import gc
import time
from collections.abc import Generator
from dataclasses import dataclass, field
from functools import cache

import psutil
import torch
import torch.types

14
15
from vllm.platforms import current_platform

16
from .mem_constants import GiB_bytes, MiB_bytes
17
18


def format_mib(b: int) -> str:
    """Render a byte count in MiB units, rounded to two decimal places."""
    mib = b / MiB_bytes
    return format(round(mib, 2))


def format_gib(b: int) -> str:
    """Render a byte count in GiB units, rounded to two decimal places."""
    gib = b / GiB_bytes
    return format(round(gib, 2))


@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Return the maximum shared memory per thread block, in bytes.

    The value comes from the custom op
    ``get_max_shared_memory_per_block_device_attribute`` and is memoized per
    ``gpu`` argument by ``functools.cache``.

    Args:
        gpu: Device index forwarded to the custom op.

    Raises:
        AssertionError: if the reported value is not positive (only while
            assertions are enabled).
    """
    # Function-scope import; presumably keeps the compiled custom ops from
    # being loaded at module import time -- TODO confirm.
    from vllm import _custom_ops as ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A value of 0 would make MAX_SEQ_LEN negative and test_attention.py
    # would fail.
    assert max_shared_mem > 0, "max_shared_mem cannot be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
    """Return the node's total CPU memory in bytes.

    The figure is ``psutil.virtual_memory().total``.
    """
    vm_stats = psutil.virtual_memory()
    return vm_stats.total


class DeviceMemoryProfiler:
    def __init__(self, device: torch.types.Device | None = None):
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        gc.collect()
        return current_platform.get_current_memory_usage(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()


@dataclass
class MemorySnapshot:
    """Point-in-time snapshot of device memory usage. Sizes are in bytes.

    Unless ``auto_measure`` is False, ``__post_init__`` calls :meth:`measure`
    right away. ``a - b`` returns a new, unmeasured snapshot holding the
    field-wise difference; both operands must be on the same device.
    """

    # memory_stats(device)["allocated_bytes.all.peak"] (0 when absent).
    torch_peak: int = 0
    # Free/total device memory from current_platform.mem_get_info; on the UMA
    # devices listed in measure(), free_memory is replaced by
    # psutil.virtual_memory().available.
    free_memory: int = 0
    total_memory: int = 0
    # total_memory - free_memory.
    cuda_memory: int = 0
    # current_platform.memory_reserved(device).
    torch_memory: int = 0
    # cuda_memory - torch_memory.
    non_torch_memory: int = 0
    # time.time() at the moment of measurement.
    timestamp: float = 0.0

    # Device to measure; None selects current_platform.current_device().
    device: torch.types.Device = None
    # When True, __post_init__ measures immediately.
    auto_measure: bool = True

    def __post_init__(self) -> None:
        # Normalize the requested device into a torch.device kept on
        # ``device_`` (a plain attribute, not a dataclass field).
        if self.device is None:
            device_fn = current_platform.current_device
            assert device_fn is not None
            self.device_ = torch.device(device_fn())
        else:
            self.device_ = torch.device(self.device)

        if self.auto_measure:
            self.measure()

    def measure(self) -> None:
        """Fill in every measurement field for ``self.device_``."""
        device = self.device_

        # We measure the torch peak memory usage via allocated_bytes,
        # rather than `torch.cuda.memory_reserved()`.
        # After `torch.cuda.reset_peak_memory_stats()`,
        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
        # when we call `torch.accelerator.empty_cache()` or OOM happens.
        self.torch_peak = current_platform.memory_stats(device).get(
            "allocated_bytes.all.peak", 0
        )

        self.free_memory, self.total_memory = current_platform.mem_get_info(device)

        shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
        if (
            current_platform.is_cuda()
            and current_platform.get_device_capability(device.index)
            in shared_sysmem_device_mem_sms
        ):
            # On UMA platforms (Orin, Thor and Spark), CPU and GPU share system
            # memory. There, cudaMemGetInfo / torch.cuda.mem_get_info() only
            # report "free" system memory. That figure can be lower than what
            # is actually available because it excludes cache memory, so the
            # OS "available" figure is used instead. NVIDIA's reference page
            # explains how to estimate the proper value:
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
            self.free_memory = psutil.virtual_memory().available

        self.cuda_memory = self.total_memory - self.free_memory

        # torch.cuda.memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
        # this is used to measure the non-torch memory usage
        self.torch_memory = current_platform.memory_reserved(device)

        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        """Return the field-wise difference ``self - other`` as a new snapshot.

        Raises:
            ValueError: if the two snapshots were taken on different devices.
        """
        if self.device_ != other.device_:
            raise ValueError(
                "The two snapshots should be from the same device! "
                f"Found: {self.device_} vs. {other.device_}"
            )

        # auto_measure=False: the result carries the differences and must not
        # be overwritten by a fresh measurement.
        return MemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            cuda_memory=self.cuda_memory - other.cuda_memory,
            torch_memory=self.torch_memory - other.torch_memory,
            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
            timestamp=self.timestamp - other.timestamp,
            device=self.device_,
            auto_measure=False,
        )

    def __repr__(self) -> str:
        return (
            f"torch_peak={format_gib(self.torch_peak)}GiB, "
            f"free_memory={format_gib(self.free_memory)}GiB, "
            f"total_memory={format_gib(self.total_memory)}GiB, "
            f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB, "
            f"torch_memory={format_gib(self.torch_memory)}GiB, "
            f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, "
            f"timestamp={self.timestamp}, "
            f"auto_measure={self.auto_measure}"
        )


@dataclass
class MemoryProfilingResult:
    """Memory profiling result. All numbers are in bytes.

    ``memory_profiling`` in this module creates and fills in instances.
    ``__post_init__`` also attaches two unmeasured snapshots, ``before_profile``
    and ``after_profile``. They are not dataclass fields and use the same
    device as ``before_create``.
    """

    # weights_memory + torch_peak_increase + non_torch_increase, as computed
    # by memory_profiling.
    non_kv_cache_memory: int = 0
    # torch_peak difference between after_profile and before_profile.
    torch_peak_increase: int = 0
    # non_torch_memory difference between after_profile and before_create.
    non_torch_increase: int = 0
    # Memory held by the model weights, as supplied by the caller.
    weights_memory: int = 0
    # Snapshot from before the current vLLM instance was created. When left
    # defaulted, MemorySnapshot() is constructed and measures immediately
    # (its auto_measure defaults to True).
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    # Seconds between the before_profile and after_profile measurements.
    profile_time: float = 0.0

    def __post_init__(self) -> None:
        device = self.before_create.device_

        # auto_measure=False: the caller decides when each snapshot is taken.
        self.before_profile = MemorySnapshot(device=device, auto_measure=False)
        self.after_profile = MemorySnapshot(device=device, auto_measure=False)

    def __repr__(self) -> str:
        return (
            f"Memory profiling takes {self.profile_time:.2f} seconds. "
            f"Total non KV cache memory: "
            f"{format_gib(self.non_kv_cache_memory)}GiB; "
            f"torch peak memory increase: "
            f"{format_gib(self.torch_peak_increase)}GiB; "
            f"non-torch forward increase memory: "
            f"{format_gib(self.non_torch_increase)}GiB; "
            f"weights memory: {format_gib(self.weights_memory)}GiB."
        )


@contextlib.contextmanager
def memory_profiling(
    baseline_snapshot: MemorySnapshot,
    weights_memory: int = 0,
) -> Generator[MemoryProfilingResult, None, None]:
    """
    Memory profiling context manager.

    Args:
        baseline_snapshot: the memory snapshot before the current vLLM instance.
        weights_memory: memory used by PyTorch when loading the model weights.
            Note that, before loading the model weights, we also initialize the
            device and distributed environment, which may consume some memory.
            This part is not included in the weights_memory because PyTorch
            does not control it.

    Yields:
        A MemoryProfilingResult whose ``before_profile`` snapshot has already
        been measured. ``after_profile``, ``torch_peak_increase``,
        ``non_torch_increase``, ``profile_time`` and ``non_kv_cache_memory``
        are filled in only when the ``with`` block exits normally. If the block
        raises, they keep their default values.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is given directly by the
    argument `weights_memory`.

    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
    during profiling gives (b.).

    The increase of `non_torch_memory` from before creating the current vLLM
    instance until after profiling gives (c.).
    """
    # Collect garbage, release cached allocator blocks and reset the peak
    # counter before the pre-profiling snapshot is taken.
    gc.collect()
    torch.accelerator.empty_cache()
    current_platform.reset_peak_memory_stats(baseline_snapshot.device_)

    result = MemoryProfilingResult(
        before_create=baseline_snapshot,
        # the part of memory used for holding the model weights
        weights_memory=weights_memory,
    )

    result.before_profile.measure()

    yield result

    # Free what the profiled code left behind (e.g. activation tensors)
    # before the final measurement.
    gc.collect()
    torch.accelerator.empty_cache()

    result.after_profile.measure()

    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp

    non_torch_memory = result.non_torch_increase
    peak_activation_memory = result.torch_peak_increase
    result.non_kv_cache_memory = (
        non_torch_memory + peak_activation_memory + result.weights_memory
    )