Unverified Commit 3774f078 authored by Stefan He, committed by GitHub

Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)

parent 9179ea15
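In short, this commit threads an optional `tags` argument (`"weights"`, `"kv_cache"`, defined in the new `sglang.srt.constants` module below) through `Engine.release_memory_occupation` and `Engine.resume_memory_occupation`, so an RL trainer can offload and restore the KV cache and the model weights in separate stages. A minimal sketch of the intended workflow, assembled from the Engine API and the test in this commit (the model path and `hf_model` are illustrative placeholders):

```python
import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    enable_memory_saver=True,
)

# Rollout: collect experience as usual.
engine.generate("Today is a sunny day and I like", {"max_new_tokens": 8})

# Stage 1: release the KV cache, then the weights (tags=None releases both).
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])

# ... the trainer now owns the freed GPU memory ...

# Stage 2: resume the weights first so updated parameters can be synced in,
# then bring the KV cache back for the next rollout.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.update_weights_from_tensor(list(hf_model.named_parameters()))  # hf_model: trained HF model (placeholder)
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
```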
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-torch_memory_saver = ["torch_memory_saver>=0.0.4"]
+torch_memory_saver = ["torch_memory_saver>=0.0.8"]
 decord = ["decord"]
 test = [
     "accelerate",
...
+# GPU Memory Types
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
+GPU_MEMORY_TYPE_WEIGHTS = "weights"
@@ -31,6 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
...
@@ -479,17 +479,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

-    def release_memory_occupation(self):
-        """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )

-    def resume_memory_occupation(self):
-        """Resume GPU occupation."""
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -670,11 +668,9 @@ def _launch_subprocesses(
     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
         scheduler_pipe_readers = []
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +706,7 @@ def _launch_subprocesses(
                     writer,
                 ),
             )
             with memory_saver_adapter.configure_subprocess():
                 proc.start()
             scheduler_procs.append(proc)
...
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:

 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:

 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
...
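For reference, a quick sketch of how the new request dataclasses behave (module path assumed from sglang's source tree; the `tags` strings are the constants defined above):

```python
from sglang.srt.managers.io_struct import ReleaseMemoryOccupationReqInput

# tags=None (the default) keeps the old behavior: release everything.
req_all = ReleaseMemoryOccupationReqInput()

# A tagged request releases only the named region(s).
req_kv = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
```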
@@ -36,6 +36,7 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -450,8 +451,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
-
-        # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
@@ -2227,23 +2226,40 @@ class Scheduler(
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(
-            caller_name="release_memory_occupation"
-        )
-        self.stashed_model_static_state = _export_static_state(
-            self.tp_worker.worker.model_runner.model
-        )
-        self.memory_saver_adapter.pause()
-        self.flush_cache()
+        tags = recv_req.tags
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
-        self.memory_saver_adapter.resume()
-        _import_static_state(
-            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
-        )
-        del self.stashed_model_static_state
+        tags = recv_req.tags
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
         return ResumeMemoryOccupationReqOutput()

     def slow_down(self, recv_req: SlowDownReqInput):
...
@@ -35,6 +35,7 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
@@ -54,6 +55,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -61,7 +63,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
         )

     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
                 if self.enable_custom_mem_pool
@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
         else:
             self.custom_mem_pool = None

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
                 if self.custom_mem_pool
@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
...
@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -222,6 +223,7 @@ class ModelRunner:
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -547,7 +549,7 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state()
             monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
...
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext

 try:
     import torch_memory_saver

-    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError

-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError

-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError

-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError

     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):

 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()

-    def region(self):
-        return _primary_memory_saver.region()
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)

-    def pause(self):
-        return _primary_memory_saver.pause()
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)

-    def resume(self):
-        return _primary_memory_saver.resume()
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)

     @property
     def enabled(self):
-        return _primary_memory_saver.enabled
+        return _memory_saver is not None and _memory_saver.enabled


 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield

     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield

-    def pause(self):
+    def pause(self, tag: str):
         pass

-    def resume(self):
+    def resume(self, tag: str):
         pass

     @property
...
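The adapter now forwards a `tag` to `torch_memory_saver`'s global singleton, which is what makes per-region pause/resume possible (and why the dependency pin above moves to `torch_memory_saver>=0.0.8`). A rough sketch of the mechanics, assuming the library's documented behavior that paused tensors keep their virtual addresses while physical pages are released; the buffer shape is arbitrary:

```python
import torch

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)

# Allocations made inside a tagged region are tracked under that tag.
with adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
    kv_buffer = torch.zeros((4096, 4096), dtype=torch.bfloat16, device="cuda")

# Pause releases the physical pages for this tag only; other tags
# (e.g. the model weights) are untouched.
adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)

# ... the freed memory can be used by another workload here ...

# Resume re-materializes the pages, and kv_buffer is usable again.
adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
```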
@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
 # General test models
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
...
@@ -74,7 +74,6 @@ suites = {
         TestFile("test_radix_attention.py", 105),
         TestFile("test_reasoning_content.py", 89),
         TestFile("test_regex_constrained.py", 64),
-        TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),
         TestFile("test_retract_decode.py", 54),
         TestFile("test_server_args.py", 1),
@@ -146,6 +145,7 @@ suites = {
         TestFile("test_patch_torch.py", 19),
         TestFile("test_update_weights_from_distributed.py", 103),
         TestFile("test_verl_engine_2_gpu.py", 64),
+        TestFile("test_release_memory_occupation.py", 44),
     ],
     "per-commit-2-gpu-amd": [
         TestFile("models/lora/test_lora_tp.py", 116),
...
"""Test memory release and resume operations for SGLang engine in hybrid RL training.
This test suite evaluates the SGLang engine's memory management capabilities, focusing
on releasing and resuming memory occupation for KV cache and model weights. It simulates
an RL workflow where the SGLang engine acts as a rollout engine for experience collection.
The process involves initializing the engine, sending a small number of requests to simulate
rollout, releasing memory to mimic offloading during RL training, resuming memory occupation,
updating weights with a trained HuggingFace model, and verifying the updated weights.
Detailed in our proposal (https://github.com/sgl-project/sglang/pull/7099), two test cases
are included:
1. Basic Release and Resume: Uses a lower mem_fraction_static (0.6) to control memory allocation
and avoid OOM errors carefully. This test simulates a scenario without multi-stage memory management,
ensuring the engine can release and resume memory occupation while maintaining functionality after
weight updates.
2. Multi-Stage Release and Resume: Employs a higher mem_fraction_static (0.85) to simulate higher
memory pressure, leveraging multi-stage memory management. It sequentially releases and resumes
KV cache and model weights, verifying memory deallocation and reallocation at each stage, and
ensuring correct weight updates and text generation.
3. Tensor Parallel Tests: Tests memory release and resume operations with different tensor parallel
configurations (tp=1, tp=2) to ensure proper memory management in distributed settings. For different
data parallel size, we test it in verl.
"""
import gc
import os
 import time
 import unittest
@@ -5,93 +34,221 @@ import torch
 from transformers import AutoModelForCausalLM

 import sglang as sgl
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
+    CustomTestCase,
+)

 # (temporarily) set to true to observe memory usage in nvidia-smi more clearly
-_DEBUG_EXTRA = True
+_DEBUG_EXTRA = False
-
-
-class TestReleaseMemoryOccupation(CustomTestCase):
-    def test_release_and_resume_occupation(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
-        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-        expect_output = " to spend it outdoors. I decided to"
-
-        engine = sgl.Engine(
-            model_path=model_name,
-            random_seed=42,
-            enable_memory_saver=True,
-            # disable_cuda_graph=True,  # for debugging only
-        )
-        hf_model_new = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype="bfloat16"
-        )
-
-        print("generate (#1)")
-        outputs = engine.generate(prompt, sampling_params)["text"]
-        self.assertEqual(outputs, expect_output)
-
-        if _DEBUG_EXTRA:
-            time.sleep(3)
-
-        self.assertEqual(
-            _try_allocate_big_tensor(),
-            False,
-            "Should not be able to allocate big tensors before releasing",
-        )
-
-        print("release_memory_occupation start")
-        t = time.perf_counter()
-        engine.release_memory_occupation()
-        if _DEBUG_EXTRA:
-            print("release_memory_occupation", time.perf_counter() - t)
-
-        if _DEBUG_EXTRA:
-            time.sleep(5)
-
-        self.assertEqual(
-            _try_allocate_big_tensor(),
-            True,
-            "Should be able to allocate big tensors aftre releasing",
-        )
-
-        if _DEBUG_EXTRA:
-            time.sleep(5)
-
-        print("resume_memory_occupation start")
-        t = time.perf_counter()
-        engine.resume_memory_occupation()
-        if _DEBUG_EXTRA:
-            print("resume_memory_occupation", time.perf_counter() - t)
-
-        self.assertEqual(
-            _try_allocate_big_tensor(),
-            False,
-            "Should not be able to allocate big tensors after resuming",
-        )
-
-        print("update_weights_from_tensor")
-        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
-        engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
-
-        print("generate (#2)")
-        outputs = engine.generate(prompt, sampling_params)["text"]
-        self.assertEqual(outputs, expect_output)
-
-        if _DEBUG_EXTRA:
-            time.sleep(4)
-
-        engine.shutdown()
-
-
-def _try_allocate_big_tensor(size: int = 20_000_000_000):
-    try:
-        torch.empty((size,), dtype=torch.uint8, device="cuda")
-        torch.cuda.empty_cache()
-        return True
-    except torch.cuda.OutOfMemoryError:
-        return False
+
+
+def get_gpu_memory_gb():
+    return torch.cuda.device_memory_used() / 1024**3
+
+
+class TestReleaseMemoryOccupation(CustomTestCase):
+    def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1):
+        """Common setup for the engine."""
+        engine = sgl.Engine(
+            model_path=model_name,
+            random_seed=42,
+            enable_memory_saver=True,
+            mem_fraction_static=mem_fraction_static,
+            tp_size=tp_size,
+            # disable_cuda_graph=True,  # for debugging only
+        )
+        return engine
+
+    def _common_test_params(self):
+        """Common test parameters."""
+        return {
+            "prompt": "Today is a sunny day and I like",
+            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
+            "expect_output_before_update_weights": " to spend it outdoors. I decided to",
+            "expect_output_after_update_weights": " to go for a walk. I like",
+        }
+
+    def _test_initial_generation(
+        self, engine, prompt, sampling_params, expect_output_before_update_weights
+    ):
+        """Test initial generation and memory allocation."""
+        print("generate (#1)")
+        outputs = engine.generate(prompt, sampling_params)["text"]
+        self.assertEqual(outputs, expect_output_before_update_weights)
+
+        if _DEBUG_EXTRA:
+            time.sleep(3)
+
+    def test_release_and_resume_occupation(self):
+        # Without multi-stage release and resume, we need to carefully control
+        # the memory fraction to avoid OOM.
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        assert (
+            torch.cuda.device_count() >= 2
+        ), "Need at least 2 GPUs for tensor parallel tests"
+
+        for tp_size in [1, 2]:
+            print(f"Testing tp_size={tp_size} for test_release_and_resume_occupation")
+            engine = self._setup_engine(
+                model_name=model_name, mem_fraction_static=0.6, tp_size=tp_size
+            )
+            params = self._common_test_params()
+            self._test_initial_generation(
+                engine,
+                params["prompt"],
+                params["sampling_params"],
+                params["expect_output_before_update_weights"],
+            )
+
+            t = time.perf_counter()
+            gpu_memory_usage_before_release = get_gpu_memory_gb()
+            engine.release_memory_occupation()
+            gpu_memory_usage_after_release = get_gpu_memory_gb()
+
+            self.assertLess(
+                gpu_memory_usage_after_release,
+                gpu_memory_usage_before_release,
+            )
+
+            print(
+                f"Release took {time.perf_counter() - t:.2f}s, memory: {gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
+            )
+
+            if _DEBUG_EXTRA:
+                time.sleep(3)
+
+            t = time.perf_counter()
+            engine.resume_memory_occupation()
+            print(
+                f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
+            )
+
+            hf_model_new = AutoModelForCausalLM.from_pretrained(
+                DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+            engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
+
+            # destroy the hf model
+            del hf_model_new
+            torch.cuda.empty_cache()
+
+            print("generate (#2)")
+            outputs = engine.generate(params["prompt"], params["sampling_params"])[
+                "text"
+            ]
+            self.assertEqual(outputs, params["expect_output_after_update_weights"])
+
+            engine.shutdown()
+
+    def test_multi_stage_release_and_resume(self):
+        # With multi-stage release and resume, we can set the memory fraction
+        # to 0.85 without concern of OOM.
+        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+        for tp_size in [1, 2]:
+            if tp_size == 2 and torch.cuda.device_count() < 2:
+                continue
+
+            print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
+            engine = sgl.Engine(
+                model_path=model_name,
+                random_seed=42,
+                enable_memory_saver=True,
+                mem_fraction_static=0.85,  # Higher memory pressure
+                tp_size=tp_size,
+            )
+            params = self._common_test_params()
+            self._test_initial_generation(
+                engine,
+                params["prompt"],
+                params["sampling_params"],
+                params["expect_output_before_update_weights"],
+            )
+
+            t = time.perf_counter()
+            gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
+            engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
+            gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()
+            self.assertLess(
+                gpu_memory_usage_after_release_kv_cache,
+                gpu_memory_usage_before_release_kv_cache,
+            )
+
+            engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
+            gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
+            self.assertLess(
+                gpu_memory_usage_after_release_weights,
+                gpu_memory_usage_after_release_kv_cache,
+            )
+
+            print(f"Release took {time.perf_counter() - t:.2f}s")
+            print(
+                f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} GB"
+            )
+
+            if _DEBUG_EXTRA:
+                time.sleep(3)
+
+            t = time.perf_counter()
+            gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
+            # gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
+            self.assertAlmostEqual(
+                gpu_memory_usage_after_release_weights,
+                gpu_memory_usage_before_resume_weights,
+                delta=3.0,
+            )
+            print(f"Resume weights took {time.perf_counter() - t:.2f}s")
+
+            engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
+            gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
+            self.assertGreater(
+                gpu_memory_usage_after_resume_weights,
+                gpu_memory_usage_before_resume_weights,
+            )
+
+            # Update weights from a trained model to the serving engine, and then
+            # destroy the trained model.
+            hf_model_new = AutoModelForCausalLM.from_pretrained(
+                DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+            gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb()
+            engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
+
+            # destroy the hf model
+            del hf_model_new
+            torch.cuda.empty_cache()
+
+            engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
+            gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb()
+            self.assertGreater(
+                gpu_memory_usage_after_resume_kv_cache,
+                gpu_memory_usage_after_resume_weights,
+            )
+
+            print(f"Resume + update took {time.perf_counter() - t:.2f}s")
+            print(
+                f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB"
+            )
+
+            print("generate (#2)")
+            outputs = engine.generate(params["prompt"], params["sampling_params"])[
+                "text"
+            ]
+            self.assertEqual(outputs, params["expect_output_after_update_weights"])
+
+            engine.shutdown()


 if __name__ == "__main__":
...
@@ -235,7 +235,8 @@ def _run_subprocess(
     output_writer.send(execution_ok)
     output_writer.close()

-    engine.shutdown()
+    if "engine" in locals() and engine is not None:
+        engine.shutdown()
     print(f"subprocess[{rank=}] end", flush=True)
...
@@ -249,7 +249,8 @@ def _run_subprocess(
     output_writer.send(execution_ok)
     output_writer.close()

-    engine.shutdown()
+    if "engine" in locals() and engine is not None:
+        engine.shutdown()
     print(f"subprocess[{rank=}] end", flush=True)
...