Unverified Commit 8c574902 authored by Matt Nappo, committed by GitHub

[Feature] Option to save model weights to CPU when memory saver mode is enabled (#10873)


Co-authored-by: molocule <34072934+molocule@users.noreply.github.com>
parent 34151f17
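
Below is a hedged usage sketch of the new option (the model path and sampling parameters are placeholders; the flag names and the release/resume calls are taken from the diff that follows):

```python
import sglang as sgl

# Enable the memory saver plus the new CPU backup option, so releasing GPU
# memory copies the weights to host RAM instead of discarding them.
engine = sgl.Engine(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    enable_memory_saver=True,
    enable_weights_cpu_backup=True,
)

sampling_params = {"temperature": 0, "max_new_tokens": 8}
before = engine.generate("The capital of France is", sampling_params)["text"]

engine.release_memory_occupation()  # GPU weights released, backed up to CPU
engine.resume_memory_occupation()   # weights restored from the CPU copy

# With greedy sampling the restored weights should reproduce the same text.
after = engine.generate("The capital of France is", sampling_params)["text"]
assert before == after

engine.shutdown()
```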
@@ -305,6 +305,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--num-continuous-decode-steps` | Run multiple continuous decoding steps to reduce scheduling overhead. This can potentially increase throughput but may also increase time-to-first-token latency. The default value is 1, meaning only run one decoding step at a time. | 1 |
 | `--delete-ckpt-after-loading` | Delete the model checkpoint after loading the model. | False |
 | `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False |
+| `--enable-weights-cpu-backup` | Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation. | False |
 | `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False |
 | `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False |
 | `--flashinfer-mla-disable-ragged` | Disable ragged processing in Flashinfer MLA. | False |
...
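The same configuration expressed as a server launch, sketched here as a Python subprocess call (`python -m sglang.launch_server` is SGLang's standard launcher; the model path is a placeholder):

```python
import subprocess

# Launch the server with the memory saver enabled and weights backed up to
# CPU whenever memory occupation is released.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    "--enable-memory-saver",
    "--enable-weights-cpu-backup",
])
```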
@@ -58,7 +58,7 @@ dependencies = [
     "tiktoken",
     "timm==1.0.16",
     "torch==2.8.0",
-    "torch_memory_saver==0.0.8",
+    "torch_memory_saver==0.0.9rc1",
     "torchao==0.9.0",
     "torchaudio==2.8.0",
     "torchvision",
...
@@ -110,7 +110,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai==1.99.1", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-torch_memory_saver = ["torch_memory_saver==0.0.8"]
+torch_memory_saver = ["torch_memory_saver==0.0.9rc1"]
 decord = ["decord"]
 test = [
     "accelerate",
...
@@ -25,9 +25,7 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse
 
-import requests
 import torch
 import torch.distributed as dist
@@ -36,7 +34,6 @@ from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -132,7 +129,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -143,7 +139,6 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (
@@ -616,7 +611,7 @@ class ModelRunner:
             server_args.hicache_io_backend = "direct"
             logger.warning(
                 "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
             )
 
     def init_torch_distributed(self):
@@ -778,7 +773,10 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -1106,7 +1104,7 @@ class ModelRunner:
                 handle.wait()
 
             self.model.load_weights(weights)
-            return True, f"Succeeded to update parameter online."
+            return True, "Succeeded to update parameter online."
         except Exception as e:
             error_msg = (
@@ -1749,8 +1747,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-                f"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
-                f"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
...
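For context, a hedged sketch of what the `region` context manager provides, based on the usage pattern visible in this diff (the `resume` call is assumed to mirror `pause`; exact library internals may differ):

```python
import torch
from torch_memory_saver import torch_memory_saver

# Tensors created inside a tagged region become pausable: pause() releases
# their GPU memory and resume() restores it. With enable_cpu_backup=True the
# tensor contents are also copied to host RAM on pause, so weights can be
# restored without reloading the checkpoint from disk.
with torch_memory_saver.region(tag="weights", enable_cpu_backup=True):
    weights = torch.empty(1024, 1024, device="cuda")

torch_memory_saver.pause(tag="weights")   # GPU pages freed; contents kept on CPU
torch_memory_saver.resume(tag="weights")  # contents copied back onto the GPU
```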
@@ -400,6 +400,7 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    enable_weights_cpu_backup: bool = False
     allow_auto_truncate: bool = False
     enable_custom_logit_processor: bool = False
     flashinfer_mla_disable_ragged: bool = False
@@ -2541,6 +2542,11 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--enable-weights-cpu-backup",
+            action="store_true",
+            help="Save model weights to CPU memory during release_weights_occupation and resume_weights_occupation",
+        )
         parser.add_argument(
             "--allow-auto-truncate",
             action="store_true",
...
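The two hunks above follow the standard ServerArgs convention: a dataclass field supplies the default, and a matching `store_true` flag overrides it, with the dashed flag name mapping onto the underscored field name. A minimal self-contained sketch of that pairing (hypothetical simplified class, not the real ServerArgs):

```python
import argparse
from dataclasses import dataclass, fields

# Hypothetical, simplified stand-in for ServerArgs.
@dataclass
class Args:
    enable_weights_cpu_backup: bool = False

parser = argparse.ArgumentParser()
# store_true makes the flag default to False, matching the dataclass field.
parser.add_argument("--enable-weights-cpu-backup", action="store_true")

# argparse turns dashes into underscores, so the namespace attribute name
# lines up with the dataclass field name.
ns = parser.parse_args(["--enable-weights-cpu-backup"])
args = Args(**{f.name: getattr(ns, f.name) for f in fields(Args)})
assert args.enable_weights_cpu_backup is True
```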
 import logging
-import threading
-import time
 from abc import ABC
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 
 try:
     import torch_memory_saver
@@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self, tag: str):
+    def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError
 
     def pause(self, tag: str):
@@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self, tag: str):
-        return _memory_saver.region(tag=tag)
+    def region(self, tag: str, enable_cpu_backup: bool = False):
+        return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
 
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
@@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self, tag: str):
+    def region(self, tag: str, enable_cpu_backup: bool = False):
        yield
 
     def pause(self, tag: str):
...
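The Real and Noop adapters change in lockstep so that call sites never have to check whether torch_memory_saver is installed. A condensed sketch of that pattern with simplified, hypothetical names (not the actual SGLang classes):

```python
from abc import ABC
from contextlib import contextmanager

class Adapter(ABC):
    def region(self, tag: str, enable_cpu_backup: bool = False):
        raise NotImplementedError

class RealAdapter(Adapter):
    # Forwards to the library; the new keyword is threaded straight through.
    def region(self, tag: str, enable_cpu_backup: bool = False):
        from torch_memory_saver import torch_memory_saver
        return torch_memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)

class NoopAdapter(Adapter):
    # Same signature, does nothing: disabled mode costs no branching at call sites.
    @contextmanager
    def region(self, tag: str, enable_cpu_backup: bool = False):
        yield

def get_adapter(enabled: bool) -> Adapter:
    return RealAdapter() if enabled else NoopAdapter()

with get_adapter(enabled=False).region("weights", enable_cpu_backup=True):
    pass  # allocations here are tracked only when the real adapter is active
```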
@@ -25,8 +25,6 @@ configurations (tp=1, tp=2) to ensure proper memory management in distributed se
 data parallel size, we test it in verl.
 """
 
-import gc
-import os
 import time
 import unittest
@@ -52,7 +50,14 @@ def get_gpu_memory_gb():
 class TestReleaseMemoryOccupation(CustomTestCase):
-    def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1, ep_size=1):
+    def _setup_engine(
+        self,
+        model_name,
+        mem_fraction_static=0.8,
+        tp_size=1,
+        ep_size=1,
+        enable_weights_cpu_backup=False,
+    ):
         """Common setup for engine and HF model."""
         engine = sgl.Engine(
             model_path=model_name,
@@ -61,6 +66,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
             mem_fraction_static=mem_fraction_static,
             tp_size=tp_size,
             ep_size=ep_size,
+            enable_weights_cpu_backup=enable_weights_cpu_backup,
             # disable_cuda_graph=True,  # for debugging only
         )
@@ -153,6 +159,53 @@ class TestReleaseMemoryOccupation(CustomTestCase):
         self.assertEqual(outputs, params["expect_output_after_update_weights"])
 
         engine.shutdown()
 
+    def test_release_and_resume_occupation_with_weights_cpu_backup(self):
+        # Test release and resume occupation with weights CPU backup
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        print("Testing test_release_and_resume_occupation_with_weights_cpu_backup")
+        engine = self._setup_engine(
+            model_name=model_name,
+            mem_fraction_static=0.6,
+            enable_weights_cpu_backup=True,
+        )
+        params = self._common_test_params()
+
+        self._test_initial_generation(
+            engine,
+            params["prompt"],
+            params["sampling_params"],
+            params["expect_output_before_update_weights"],
+        )
+
+        t = time.perf_counter()
+        gpu_memory_usage_before_release = get_gpu_memory_gb()
+        engine.release_memory_occupation()
+        gpu_memory_usage_after_release = get_gpu_memory_gb()
+
+        self.assertLess(
+            gpu_memory_usage_after_release,
+            gpu_memory_usage_before_release,
+        )
+        print(
+            f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
+        )
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+        t = time.perf_counter()
+        engine.resume_memory_occupation()
+        print(
+            f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
+        )
+
+        print("generate post resume")
+        outputs = engine.generate(params["prompt"], params["sampling_params"])["text"]
+        self.assertEqual(outputs, params["expect_output_before_update_weights"])
+
+        engine.shutdown()
+
     def test_multi_stage_release_and_resume(self):
         # With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM
         model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...
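The assertions above rely on a `get_gpu_memory_gb()` helper defined earlier in the test file but outside this diff. A plausible minimal implementation, assuming it reports used memory on device 0 in GiB (the real helper may differ):

```python
import torch

def get_gpu_memory_gb() -> float:
    # Used (not free) memory on device 0, in GiB. mem_get_info returns
    # (free_bytes, total_bytes) for the given device.
    free, total = torch.cuda.mem_get_info(0)
    return (total - free) / (1 << 30)
```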