Unverified Commit 923f5183 authored by fzyzcjy, committed by GitHub

CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

parent d08c77c4
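
The diff below wires memory release/resume through the whole stack (IO structs, scheduler, tokenizer manager, memory pools, model runner, HTTP server, CLI flag, and a test). As a quick orientation, here is a minimal usage sketch built only from the `Engine` methods and the test added in this commit; the model path and the idea of a co-located trainer are illustrative, not part of the change:

```python
import sglang as sgl

# enable_memory_saver is the new flag added by this commit; the model path is a placeholder.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct", enable_memory_saver=True
)

print(engine.generate("Today is a sunny day and I like",
                      {"temperature": 0, "max_new_tokens": 8})["text"])

# Free the physical memory behind the KV cache and model weights while keeping their
# virtual addresses, so previously captured CUDA graphs remain valid.
engine.release_memory_occupation()

# ... another workload (e.g. an RLHF trainer) can use the GPU here ...

# Re-map the memory, then re-sync weights before generating again
# (the test below uses update_weights_from_tensor for this step).
engine.resume_memory_occupation()
```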
@@ -44,6 +44,7 @@ srt_hpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
+torch_memory_saver = ["torch_memory_saver"]
test = [
    "jsonlines",
    "matplotlib",
...
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid
from dataclasses import dataclass
from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-import torch
+from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -459,6 +457,26 @@ class GetWeightsByNameReqOutput:
    parameter: list


+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
@dataclass
class AbortReq:
    # The request id
...
@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
@@ -88,6 +92,7 @@ from sglang.srt.utils import (
    set_random_seed,
    suppress_other_loggers,
)
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
@@ -357,6 +362,10 @@ class Scheduler:
            t.start()
        self.parent_process = psutil.Process().parent()

+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
@@ -519,6 +528,12 @@ class Scheduler:
        elif isinstance(recv_req, GetWeightsByNameReqInput):
            parameter = self.get_weights_by_name(recv_req)
            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+            self.release_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+            self.resume_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
        elif isinstance(recv_req, ProfileReq):
            if recv_req == ProfileReq.START_PROFILE:
                self.start_profile()
@@ -1538,6 +1553,20 @@ class Scheduler:
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter

+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
@@ -1576,6 +1605,20 @@ class Scheduler:
            del self.sessions[session_id]


+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
...
@@ -53,6 +53,10 @@ from sglang.srt.managers.io_struct import (
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
@@ -188,6 +192,12 @@ class TokenizerManager:
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

        # Metrics
        if self.enable_metrics:
@@ -548,6 +558,22 @@ class TokenizerManager:
        else:
            return all_parameters

+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
@@ -627,6 +653,8 @@ class TokenizerManager:
                UpdateWeightsFromDistributedReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
+                ReleaseMemoryOccupationReqOutput,
+                ResumeMemoryOccupationReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -709,6 +737,10 @@ class TokenizerManager:
                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                self.get_weights_by_name_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
+                self.release_memory_occupation_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
+                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
"""
Memory pool.
@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
        self.size = size
        self.max_context_len = max_context_len
        self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
        self.free_slots = list(range(size))
        self.write_records = []
        self.use_records = use_records
@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
        head_dim: int,
        layer_num: int,
        device: str,
+        enable_memory_saver: bool,
    ):
        super().__init__(size, dtype, device)

+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
        )

    def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]

    def _clear_buffers(self):
        del self.k_buffer
@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
+        enable_memory_saver: bool,
    ):
        super().__init__(size, dtype, device)

        self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]

    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
        layer_num: int,
        device: str,
        heavy_channel_num: int,
+        enable_memory_saver: bool,
    ):
        super().__init__(size, dtype, device)

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]
...
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
    monkey_patch_vllm_p2p_access_check,
    set_cpu_offload_max_bytes,
)
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

logger = logging.getLogger(__name__)
@@ -166,6 +167,10 @@ class ModelRunner:
        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
        # Load the model
        self.sampler = Sampler()
        self.load_model()
@@ -272,11 +277,12 @@ class ModelRunner:
            monkey_patch_vllm_gguf_config()

        # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
@@ -417,7 +423,7 @@ class ModelRunner:
        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
@@ -590,6 +596,7 @@ class ModelRunner:
            max_context_len=self.model_config.context_len + 4,
            device=self.device,
            use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
        )
        if (
            self.model_config.attention_arch == AttentionArch.MLA
@@ -602,6 +609,7 @@ class ModelRunner:
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -612,6 +620,7 @@ class ModelRunner:
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
@@ -621,6 +630,7 @@ class ModelRunner:
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
            )
        logger.info(
            f"Memory pool end. "
...
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union

import torch

+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
@@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
        return _create_error_response(e)


+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
@@ -438,6 +464,10 @@ def launch_engine(
            server_args.model_path, server_args.tokenizer_path
        )

+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
@@ -454,7 +484,8 @@ def launch_engine(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)
@@ -471,7 +502,8 @@ def launch_engine(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
@@ -897,6 +929,18 @@ class Engine:
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))

+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+

class Runtime:
    """
...
@@ -23,7 +23,6 @@ from typing import List, Optional

import torch

from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    get_amdgpu_memory_capacity,
    get_hpu_memory_capacity,
@@ -157,6 +156,7 @@ class ServerArgs:
    triton_attention_num_kv_splits: int = 8
    num_continuous_decode_steps: int = 1
    delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False

    def __post_init__(self):
        # Set missing default values
@@ -854,6 +854,11 @@ class ServerArgs:
            action="store_true",
            help="Delete the model checkpoint after loading the model.",
        )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
...
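
The same feature can be driven over HTTP once the server is launched with the new flag; a hedged sketch follows (host, port, and the empty JSON bodies are assumptions, not taken from the diff):

```python
# Launch the server first, e.g.:
#   python -m sglang.launch_server --model-path <model> --enable-memory-saver --port 30000
import requests

base_url = "http://localhost:30000"  # assumed host/port

# Ask the scheduler to release KV-cache and weight memory.
requests.post(f"{base_url}/release_memory_occupation", json={})

# ... other GPU work ...

# Re-occupy the memory; weights need to be re-synced before serving traffic again.
requests.post(f"{base_url}/resume_memory_occupation", json={})
```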
from abc import ABC
from contextlib import contextmanager

try:
    import torch_memory_saver

    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
except ImportError:
    pass


class TorchMemorySaverAdapter(ABC):
    @staticmethod
    def create(enable: bool):
        return (
            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
        )

    def configure_subprocess(self):
        raise NotImplementedError

    def region(self):
        raise NotImplementedError

    def pause(self):
        raise NotImplementedError

    def resume(self):
        raise NotImplementedError


class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
    def configure_subprocess(self):
        return torch_memory_saver.configure_subprocess()

    def region(self):
        return _primary_memory_saver.region()

    def pause(self):
        return _primary_memory_saver.pause()

    def resume(self):
        return _primary_memory_saver.resume()


class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
    @contextmanager
    def configure_subprocess(self):
        yield

    @contextmanager
    def region(self):
        yield

    def pause(self):
        pass

    def resume(self):
        pass
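
To make the adapter's role concrete, here is a small sketch of the call pattern the rest of this commit follows; the tensor shape is illustrative, and the comments describe the intended behavior (physical memory is released and later re-mapped at the same virtual addresses), as implied by the commit title rather than stated in this file:

```python
import torch

from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

# A no-op adapter is returned when the feature is disabled, so call sites stay unchanged.
adapter = TorchMemorySaverAdapter.create(enable=True)

# Allocations made inside region() are tracked by torch_memory_saver.
with adapter.region():
    kv_buffer = torch.empty((1024, 8, 128), dtype=torch.bfloat16, device="cuda")

adapter.pause()   # release the physical memory behind tracked tensors
# ... the GPU memory is now available to other processes ...
adapter.resume()  # re-map memory at the same virtual addresses; contents are not preserved
```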
@@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh"
pip install --upgrade pip
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/

-# Force reinstall flashinfer
+# Force reinstall flashinfer and torch_memory_saver
pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install torch_memory_saver --force-reinstall
pip install transformers==4.45.2 sentence_transformers accelerate peft
...
@@ -29,6 +29,7 @@ suites = {
        "test_openai_server.py",
        "test_pytorch_sampling_backend.py",
        "test_radix_attention.py",
+        "test_release_memory_occupation.py",
        "test_retract_decode.py",
        "test_server_args.py",
        "test_session_control.py",
...
import time
import unittest

import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = True


class TestReleaseMemoryOccupation(unittest.TestCase):
    def test_release_and_resume_occupation(self):
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}
        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        expect_output = " to spend it outdoors. I decided to"

        engine = sgl.Engine(
            model_path=model_name,
            random_seed=42,
            enable_memory_saver=True,
            # disable_cuda_graph=True,  # for debugging only
        )
        hf_model_new = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="bfloat16"
        )

        print("generate (#1)")
        outputs = engine.generate(prompt, sampling_params)["text"]
        self.assertEqual(outputs, expect_output)

        if _DEBUG_EXTRA:
            time.sleep(3)

        self.assertEqual(
            _try_allocate_big_tensor(),
            False,
            "Should not be able to allocate big tensors before releasing",
        )

        print("release_memory_occupation start")
        t = time.time()
        engine.release_memory_occupation()
        if _DEBUG_EXTRA:
            print("release_memory_occupation", time.time() - t)

        if _DEBUG_EXTRA:
            time.sleep(5)

        self.assertEqual(
            _try_allocate_big_tensor(),
            True,
            "Should be able to allocate big tensors after releasing",
        )

        if _DEBUG_EXTRA:
            time.sleep(5)

        print("resume_memory_occupation start")
        t = time.time()
        engine.resume_memory_occupation()
        if _DEBUG_EXTRA:
            print("resume_memory_occupation", time.time() - t)

        self.assertEqual(
            _try_allocate_big_tensor(),
            False,
            "Should not be able to allocate big tensors after resuming",
        )

        print("update_weights_from_tensor")
        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
        engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))

        print("generate (#2)")
        outputs = engine.generate(prompt, sampling_params)["text"]
        self.assertEqual(outputs, expect_output)

        if _DEBUG_EXTRA:
            time.sleep(4)

        engine.shutdown()


def _try_allocate_big_tensor(size: int = 20_000_000_000):
    try:
        torch.empty((size,), dtype=torch.uint8, device="cuda")
        torch.cuda.empty_cache()
        return True
    except torch.cuda.OutOfMemoryError:
        return False


if __name__ == "__main__":
    unittest.main()