Unverified Commit 385da2da authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Measure model memory usage (#3120)

parent 2daf23ab
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import socket import socket
import subprocess import subprocess
import uuid import uuid
import gc
from platform import uname from platform import uname
from typing import List, Tuple, Union from typing import List, Tuple, Union
from packaging.version import parse, Version from packaging.version import parse, Version
...@@ -309,3 +310,27 @@ def create_kv_caches_with_random( ...@@ -309,3 +310,27 @@ def create_kv_caches_with_random(
f"Does not support value cache of type {cache_dtype}") f"Does not support value cache of type {cache_dtype}")
value_caches.append(value_cache) value_caches.append(value_cache)
return key_caches, value_caches return key_caches, value_caches
class measure_cuda_memory:
def __init__(self, device=None):
self.device = device
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
return mem
def __enter__(self):
self.initial_memory = self.current_memory_usage()
# This allows us to call methods of the context manager if needed
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.final_memory = self.current_memory_usage()
self.consumed_memory = self.final_memory - self.initial_memory
# Force garbage collection
gc.collect()
...@@ -21,7 +21,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata ...@@ -21,7 +21,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl from vllm.utils import in_wsl, measure_cuda_memory
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -85,11 +85,17 @@ class ModelRunner: ...@@ -85,11 +85,17 @@ class ModelRunner:
self.model_config.enforce_eager = True self.model_config.enforce_eager = True
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(self.model_config, with measure_cuda_memory() as m:
self.device_config, self.model = get_model(self.model_config,
lora_config=self.lora_config, self.device_config,
parallel_config=self.parallel_config, lora_config=self.lora_config,
scheduler_config=self.scheduler_config) parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
self.model_memory_usage = m.consumed_memory
logger.info(
f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
)
vocab_size = self.model.config.vocab_size vocab_size = self.model.config.vocab_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment