Unverified Commit 14df02b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Cleanup `mem_utils.py` (#31793)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6ebb66cc
...@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int: ...@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail # will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero" assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem) return int(max_shared_mem)
...@@ -154,12 +154,16 @@ class MemoryProfilingResult: ...@@ -154,12 +154,16 @@ class MemoryProfilingResult:
non_kv_cache_memory: int = 0 non_kv_cache_memory: int = 0
torch_peak_increase: int = 0 torch_peak_increase: int = 0
non_torch_increase: int = 0 non_torch_increase: int = 0
weights_memory: float = 0 weights_memory: int = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot) before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0 profile_time: float = 0.0
def __post_init__(self) -> None:
device = self.before_create.device_
self.before_profile = MemorySnapshot(device=device, auto_measure=False)
self.after_profile = MemorySnapshot(device=device, auto_measure=False)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"Memory profiling takes {self.profile_time:.2f} seconds. " f"Memory profiling takes {self.profile_time:.2f} seconds. "
...@@ -175,9 +179,12 @@ class MemoryProfilingResult: ...@@ -175,9 +179,12 @@ class MemoryProfilingResult:
@contextlib.contextmanager @contextlib.contextmanager
def memory_profiling( def memory_profiling(
baseline_snapshot: MemorySnapshot, weights_memory: int baseline_snapshot: MemorySnapshot,
weights_memory: int = 0,
) -> Generator[MemoryProfilingResult, None, None]: ) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager. """
Memory profiling context manager.
baseline_snapshot: the memory snapshot before the current vLLM instance. baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights. weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device Note that, before loading the model weights, we also initialize the device
...@@ -217,21 +224,24 @@ def memory_profiling( ...@@ -217,21 +224,24 @@ def memory_profiling(
b. 2 GiB reserved for the peak activation tensors (category 2) b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3) c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory`. The memory used for loading weights (a.) is directly given from the
argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). The increase of `non_torch_memory` from creating the current vLLM instance
""" # noqa until after profiling to get (c.).
"""
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
result = MemoryProfilingResult() result = MemoryProfilingResult(
before_create=baseline_snapshot,
result.before_create = baseline_snapshot
# the part of memory used for holding the model weights # the part of memory used for holding the model weights
result.weights_memory = weights_memory weights_memory=weights_memory,
)
result.before_profile.measure() result.before_profile.measure()
...@@ -252,4 +262,4 @@ def memory_profiling( ...@@ -252,4 +262,4 @@ def memory_profiling(
peak_activation_memory = result.torch_peak_increase peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = ( result.non_kv_cache_memory = (
non_torch_memory + peak_activation_memory + result.weights_memory non_torch_memory + peak_activation_memory + result.weights_memory
) # noqa )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment