Unverified commit 2d120f8b authored by Zheng Wengang, committed by GitHub

[Feature][Multimodal] Implement LRU cache for multimodal embeddings (#8292)


Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 4f2e1490
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
             embedding_per_req = data_embedding_func(embedding_items_per_req)
             if not embedding_cache.put(embedding_items_hash, embedding_per_req):
                 print_warning_once(
-                    "Multimodal embedding cache is full. Consider increasing the "
-                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                    "Multimodal embedding cache is full. This typically occurs when a single "
+                    "embedding exceeds the cache size limit. Consider increasing the "
+                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                    "embedding size."
                 )
-        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
             extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
         return None
...
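For orientation, here is a minimal sketch of the caller-side pattern this hunk leaves in place. It assumes the cache lookup happens just above the lines shown and that MultiModalCache is importable from the module path in the comment (both assumptions, neither appears in this hunk): the embedding is recomputed only on a cache miss, stored with put(), and no longer freed explicitly, because eviction is handled lazily by the LRU policy added in the next file.

import torch

# Module path is an assumption; MultiModalCache is the class changed later in this commit.
from sglang.srt.managers.multimodal_cache import MultiModalCache

embedding_cache = MultiModalCache(max_size=64 * 1024 * 1024)  # 64 MB, illustrative only

def data_embedding_func(items):
    # Stand-in for the real vision-encoder call.
    return torch.zeros(len(items), 4096)

embedding_items_per_req = ["image-0", "image-1"]             # placeholder items
embedding_items_hash = hash(tuple(embedding_items_per_req))  # placeholder hash

embedding_per_req = embedding_cache.get(embedding_items_hash)
if embedding_per_req is None:  # miss: encode once and cache the result
    embedding_per_req = data_embedding_func(embedding_items_per_req)
    embedding_cache.put(embedding_items_hash, embedding_per_req)
# No embedding_cache.free(...) afterwards: the LRU policy reclaims space when needed.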
+import logging
+from collections import OrderedDict
 from typing import Dict
 
 import torch
 
+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+
 
 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""
 
     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0
 
+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
+        if mm_hash in self.mm_cache:
+            return True
         data_size = self._get_tensor_size(embedding)
-        if self.current_size + data_size > self.max_size:
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache
 
     def get(self, mm_hash: int) -> torch.Tensor:
-        return self.mm_cache.get(mm_hash)
-
-    def free(self, mm_hash: int) -> bool:
-        if mm_hash not in self.mm_cache:
-            return False
-        old_embedding = self.mm_cache.pop(mm_hash)
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None
 
     def clear(self):
         self.mm_cache.clear()
...
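To make the new eviction behavior concrete, below is a small, hedged usage sketch of MultiModalCache as modified above. The import path, and the assumption that max_size and _get_tensor_size are measured in bytes (element count times element size), are mine rather than the commit's; only put() and get() from the diff are exercised.

import torch

# Module path is an assumption.
from sglang.srt.managers.multimodal_cache import MultiModalCache

def tensor_bytes(t: torch.Tensor) -> int:
    # Assumed to mirror what _get_tensor_size computes.
    return t.nelement() * t.element_size()

emb1, emb2, emb3 = (torch.zeros(1024) for _ in range(3))  # 4 KiB each (float32)
cache = MultiModalCache(max_size=2 * tensor_bytes(emb1))  # room for exactly two entries

cache.put(1, emb1)
cache.put(2, emb2)
cache.get(1)        # touching hash 1 marks it most recently used
cache.put(3, emb3)  # cache is full, so the LRU entry (hash 2) is evicted lazily

assert cache.get(1) is not None and cache.get(3) is not None
assert cache.get(2) is None  # evicted

With the previous code, the third put() would simply have failed once the cache was full; the OrderedDict insertion order plus move_to_end() in get() is what makes the eviction pick hash 2 here.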
@@ -42,6 +42,21 @@ class TestVLMModels(CustomTestCase):
         os.environ["OPENAI_API_KEY"] = cls.api_key
         os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"
 
+    def _detect_eviction_in_logs(self, log_output):
+        """Detect if eviction events occurred in the log output."""
+        eviction_keywords = ["Cache eviction: evicted"]
+        eviction_detected = False
+        eviction_count = 0
+
+        for line in log_output.split("\n"):
+            if any(keyword in line for keyword in eviction_keywords):
+                eviction_detected = True
+                eviction_count += 1
+                print(f"Eviction detected: {line.strip()}")
+
+        return eviction_detected, eviction_count
+
     def run_mmmu_eval(
         self,
         model_version: str,
@@ -91,6 +106,140 @@ class TestVLMModels(CustomTestCase):
             timeout=3600,
         )
 
+    def _run_vlm_mmmu_test(
+        self,
+        model,
+        output_path,
+        test_name="",
+        custom_env=None,
+        log_level="info",
+        capture_output=False,
+    ):
+        """
+        Common method to run VLM MMMU benchmark test.
+
+        Args:
+            model: Model to test
+            output_path: Path for output logs
+            test_name: Optional test name for logging
+            custom_env: Optional custom environment variables
+            log_level: Log level for server (default: "info")
+            capture_output: Whether to capture server stdout/stderr
+        """
+        print(f"\nTesting model: {model.model}{test_name}")
+        process = None
+        mmmu_accuracy = 0  # Initialize to handle potential exceptions
+        server_output = ""
+
+        try:
+            # Prepare environment variables
+            process_env = os.environ.copy()
+            if custom_env:
+                process_env.update(custom_env)
+
+            # Prepare stdout/stderr redirection if needed
+            stdout_file = None
+            stderr_file = None
+            if capture_output:
+                stdout_file = open("/tmp/server_stdout.log", "w")
+                stderr_file = open("/tmp/server_stderr.log", "w")
+
+            # Launch server for testing
+            process = popen_launch_server(
+                model.model,
+                base_url=self.base_url,
+                timeout=self.time_out,
+                api_key=self.api_key,
+                other_args=[
+                    "--trust-remote-code",
+                    "--cuda-graph-max-bs",
+                    "32",
+                    "--enable-multimodal",
+                    "--mem-fraction-static",
+                    str(self.parsed_args.mem_fraction_static),  # Use class variable
+                    "--log-level",
+                    log_level,
+                ],
+                env=process_env,
+                return_stdout_stderr=(
+                    (stdout_file, stderr_file) if capture_output else None
+                ),
+            )
+
+            # Run evaluation
+            self.run_mmmu_eval(model.model, output_path)
+
+            # Get the result file
+            result_file_path = glob.glob(f"{output_path}/*.json")[0]
+            with open(result_file_path, "r") as f:
+                result = json.load(f)
+            print(f"Result{test_name}\n: {result}")
+
+            # Process the result
+            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
+            print(
+                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
+            )
+
+            # Capture server output if requested
+            if capture_output and process:
+                server_output = self._read_output_from_files()
+
+            # Assert performance meets expected threshold
+            self.assertGreaterEqual(
+                mmmu_accuracy,
+                model.mmmu_accuracy,
+                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
+            )
+
+            return server_output
+
+        except Exception as e:
+            print(f"Error testing {model.model}{test_name}: {e}")
+            self.fail(f"Test failed for {model.model}{test_name}: {e}")
+
+        finally:
+            # Ensure process cleanup happens regardless of success/failure
+            if process is not None and process.poll() is None:
+                print(f"Cleaning up process {process.pid}")
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process: {e}")
+
+            # clean up temporary files
+            if capture_output:
+                if stdout_file:
+                    stdout_file.close()
+                if stderr_file:
+                    stderr_file.close()
+                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
+                    try:
+                        if os.path.exists(filename):
+                            os.remove(filename)
+                    except Exception as e:
+                        print(f"Error removing {filename}: {e}")
+
+    def _read_output_from_files(self):
+        output_lines = []
+        log_files = [
+            ("/tmp/server_stdout.log", "[STDOUT]"),
+            ("/tmp/server_stderr.log", "[STDERR]"),
+        ]
+
+        for filename, tag in log_files:
+            try:
+                if os.path.exists(filename):
+                    with open(filename, "r") as f:
+                        for line in f:
+                            output_lines.append(f"{tag} {line.rstrip()}")
+            except Exception as e:
+                print(f"Error reading {tag.lower()} file: {e}")
+
+        return "\n".join(output_lines)
+
     def test_vlm_mmmu_benchmark(self):
         """Test VLM models against MMMU benchmark."""
         models_to_test = MODELS
@@ -99,60 +248,51 @@ class TestVLMModels(CustomTestCase):
             models_to_test = [random.choice(MODELS)]
 
         for model in models_to_test:
-            print(f"\nTesting model: {model.model}")
-            process = None
-            mmmu_accuracy = 0  # Initialize to handle potential exceptions
-
-            try:
-                # Launch server for testing
-                process = popen_launch_server(
-                    model.model,
-                    base_url=self.base_url,
-                    timeout=self.time_out,
-                    api_key=self.api_key,
-                    other_args=[
-                        "--trust-remote-code",
-                        "--cuda-graph-max-bs",
-                        "32",
-                        "--enable-multimodal",
-                        "--mem-fraction-static",
-                        str(self.parsed_args.mem_fraction_static),  # Use class variable
-                    ],
-                )
-
-                # Run evaluation
-                self.run_mmmu_eval(model.model, "./logs")
-
-                # Get the result file
-                result_file_path = glob.glob("./logs/*.json")[0]
-                with open(result_file_path, "r") as f:
-                    result = json.load(f)
-                print(f"Result \n: {result}")
-
-                # Process the result
-                mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
-                print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}")
-
-                # Assert performance meets expected threshold
-                self.assertGreaterEqual(
-                    mmmu_accuracy,
-                    model.mmmu_accuracy,
-                    f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})",
-                )
-
-            except Exception as e:
-                print(f"Error testing {model.model}: {e}")
-                self.fail(f"Test failed for {model.model}: {e}")
-
-            finally:
-                # Ensure process cleanup happens regardless of success/failure
-                if process is not None and process.poll() is None:
-                    print(f"Cleaning up process {process.pid}")
-                    try:
-                        kill_process_tree(process.pid)
-                    except Exception as e:
-                        print(f"Error killing process: {e}")
+            self._run_vlm_mmmu_test(model, "./logs")
+
+    def test_vlm_mmmu_benchmark_with_small_cache(self):
+        """Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
+        models_to_test = MODELS
+
+        if is_in_ci():
+            models_to_test = [random.choice(MODELS)]
+
+        for model in models_to_test:
+            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
+
+            # Run the test with output capture
+            server_output = self._run_vlm_mmmu_test(
+                model,
+                "./logs_small_cache",
+                test_name=" with small embedding cache (evict test)",
+                custom_env=custom_env,
+                log_level="debug",  # Enable debug logging for eviction detection
+                capture_output=True,  # Capture server output
+            )
+
+            # Print server output for debugging
+            print("Server output:\n", server_output)
+
+            # Analyze server output for eviction events
+            eviction_detected, eviction_count = self._detect_eviction_in_logs(
+                server_output
+            )
+
+            # Assert that eviction was detected (since we're using small cache)
+            self.assertTrue(
+                eviction_detected,
+                f"Expected eviction events to be detected with small cache (5MB), but none found. "
+                f"Cache size may be too large for the workload or eviction logic may not be working. "
+                f"Total log content length: {len(server_output)} characters",
+            )
+
+            print(
+                f"Eviction detection summary: {eviction_count} eviction events detected"
+            )
+
+            # Additional assertion: if eviction was detected, the test passed
+            if eviction_detected:
+                print("✅ Eviction logic successfully triggered and detected!")
 
 
 if __name__ == "__main__":
...