"vscode:/vscode.git/clone" did not exist on "c009512a93704aaa02db2877b65cc8e661b2824c"
Unverified commit 923f5183 authored by fzyzcjy, committed by GitHub

CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)

parent d08c77c4
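For reviewers, a minimal sketch of the intended end-to-end workflow, adapted from the test added at the bottom of this PR (the model constant and sampling parameters come from that test; this is an illustration only, not API surface beyond what the diff adds):

import sglang as sgl
from transformers import AutoModelForCausalLM
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

# Requires the optional torch_memory_saver package and enable_memory_saver=True.
engine = sgl.Engine(
    model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    random_seed=42,
    enable_memory_saver=True,
)
prompt = "Today is a sunny day and I like"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
print(engine.generate(prompt, sampling_params)["text"])

# Hand the GPU to another workload (e.g. an RLHF trainer), then take it back.
engine.release_memory_occupation()
# ... other GPU work runs here ...
engine.resume_memory_occupation()

# Contents of the released memory are not expected to survive, so weights are
# re-synced before generating again, exactly as the test does after a
# (pretend) PPO update.
hf_model = AutoModelForCausalLM.from_pretrained(
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST, torch_dtype="bfloat16"
)
engine.update_weights_from_tensor(list(hf_model.named_parameters()))
print(engine.generate(prompt, sampling_params)["text"])
engine.shutdown()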
......@@ -44,6 +44,7 @@ srt_hpu = ["sglang[runtime_common]"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver"]
test = [
"jsonlines",
"matplotlib",
......
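The new torch_memory_saver extra keeps the dependency opt-in: it can be pulled in with something like pip install "sglang[torch_memory_saver]" (standard pip extras syntax), or installed directly, as the CI script change further below does with pip install torch_memory_saver.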
......@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union
import torch
from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -459,6 +457,26 @@ class GetWeightsByNameReqOutput:
parameter: list
@dataclass
class ReleaseMemoryOccupationReqInput:
pass
@dataclass
class ReleaseMemoryOccupationReqOutput:
pass
@dataclass
class ResumeMemoryOccupationReqInput:
pass
@dataclass
class ResumeMemoryOccupationReqOutput:
pass
@dataclass
class AbortReq:
# The request id
......
......@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
......@@ -88,6 +92,7 @@ from sglang.srt.utils import (
set_random_seed,
suppress_other_loggers,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -357,6 +362,10 @@ class Scheduler:
t.start()
self.parent_process = psutil.Process().parent()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
......@@ -519,6 +528,12 @@ class Scheduler:
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
self.release_memory_occupation()
self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
self.resume_memory_occupation()
self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
......@@ -1538,6 +1553,20 @@ class Scheduler:
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter
def release_memory_occupation(self):
self.stashed_model_static_state = _export_static_state(
self.tp_worker.worker.model_runner.model
)
self.memory_saver_adapter.pause()
self.flush_cache()
def resume_memory_occupation(self):
self.memory_saver_adapter.resume()
_import_static_state(
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
......@@ -1576,6 +1605,20 @@ class Scheduler:
del self.sessions[session_id]
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
def run_scheduler_process(
server_args: ServerArgs,
port_args: PortArgs,
......
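A note on the two module-level helpers above: pausing the memory saver is expected to release the physical memory behind every tensor allocated inside its regions, so static state that is not re-loaded from anywhere else, namely the model's registered buffers (e.g. rotary-embedding caches), is cloned before pausing and copied back after resuming. A minimal CPU-only illustration of that round trip (the buffer name is made up):

import torch
from torch import nn

# Toy stand-in for the real model: one registered (non-parameter) buffer.
model = nn.Module()
model.register_buffer("rope_cache", torch.arange(4, dtype=torch.float32))

# Equivalent of _export_static_state: snapshot buffer contents before pausing.
stashed = dict(
    buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
)

# Pretend pause()/resume() wiped the storage in between.
model.rope_cache.zero_()

# Equivalent of _import_static_state: copy the snapshot back in place.
named_buffers = dict(model.named_buffers())
for name, saved in stashed["buffers"]:
    named_buffers[name][...] = saved

assert torch.equal(model.rope_cache, torch.arange(4, dtype=torch.float32))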
......@@ -53,6 +53,10 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -188,6 +192,12 @@ class TokenizerManager:
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.release_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
# Metrics
if self.enable_metrics:
......@@ -548,6 +558,22 @@ class TokenizerManager:
else:
return all_parameters
async def release_memory_occupation(
self,
obj: ReleaseMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.release_memory_occupation_communicator(obj)
async def resume_memory_occupation(
self,
obj: ResumeMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj)
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
......@@ -627,6 +653,8 @@ class TokenizerManager:
UpdateWeightsFromDistributedReqOutput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqOutput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
......@@ -709,6 +737,10 @@ class TokenizerManager:
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
self.get_weights_by_name_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
self.release_memory_occupation_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
self.resume_memory_occupation_communicator.handle_recv(recv_obj)
else:
raise ValueError(f"Invalid object: {recv_obj=}")
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
Memory pool.
......@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
class ReqToTokenPool:
"""A memory pool that maps a request to its token locations."""
def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
def __init__(
self,
size: int,
max_context_len: int,
device: str,
use_records: bool,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
self.size = size
self.max_context_len = max_context_len
self.device = device
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
with memory_saver_adapter.region():
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size))
self.write_records = []
self.use_records = use_records
......@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
head_dim: int,
layer_num: int,
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
......@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
)
def _create_buffers(self):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
with self.memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
def _clear_buffers(self):
del self.k_buffer
......@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
qk_rope_head_dim: int,
layer_num: int,
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.kv_lora_rank = kv_lora_rank
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
if self.store_dtype != self.dtype:
......@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
layer_num: int,
device: str,
heavy_channel_num: int,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.empty(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
]
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
with memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.empty(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id]
......
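All of the pool changes above follow one pattern: GPU buffers are allocated inside memory_saver_adapter.region() so the memory can later be paused and resumed as a group. A hedged sketch of that pattern in isolation (assumes a CUDA device and the optional torch_memory_saver package; the shapes are arbitrary, and the promise that virtual addresses stay stable across pause/resume, which is what keeps captured CUDA graphs valid, comes from that library rather than from this diff):

import torch

from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter  # import path as added in this PR

adapter = TorchMemorySaverAdapter.create(enable=True)

with adapter.region():
    # Allocations made here are tracked by the memory saver.
    k_buffer = torch.empty((4096 + 1, 8, 128), dtype=torch.bfloat16, device="cuda")

adapter.pause()   # physical memory is released; the tensor object stays alive
adapter.resume()  # memory is re-materialized, expected at the same addresses
k_buffer.zero_()  # contents are undefined after resume, so re-initialize as needed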
......@@ -60,6 +60,7 @@ from sglang.srt.utils import (
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
logger = logging.getLogger(__name__)
......@@ -166,6 +167,10 @@ class ModelRunner:
# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.server_args.enable_memory_saver
)
# Load the model
self.sampler = Sampler()
self.load_model()
......@@ -272,11 +277,12 @@ class ModelRunner:
monkey_patch_vllm_gguf_config()
# Load the model
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
with self.memory_saver_adapter.region():
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None:
......@@ -417,7 +423,7 @@ class ModelRunner:
logger.info(
f"init custom process group: master_address={master_address}, master_port={master_port}, "
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
)
try:
......@@ -590,6 +596,7 @@ class ModelRunner:
max_context_len=self.model_config.context_len + 4,
device=self.device,
use_records=False,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
......@@ -602,6 +609,7 @@ class ModelRunner:
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
......@@ -612,6 +620,7 @@ class ModelRunner:
layer_num=self.model_config.num_hidden_layers,
device=self.device,
heavy_channel_num=self.server_args.ds_heavy_channel_num,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......@@ -621,6 +630,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
logger.info(
f"Memory pool end. "
......
......@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
import torch
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
async def release_memory_occupation(
obj: ReleaseMemoryOccupationReqInput, request: Request
):
"""Release GPU occupation temporarily"""
try:
await tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e:
return _create_error_response(e)
@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
async def resume_memory_occupation(
obj: ResumeMemoryOccupationReqInput, request: Request
):
"""Resume GPU occupation"""
try:
await tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e:
return _create_error_response(e)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
......@@ -438,6 +464,10 @@ def launch_engine(
server_args.model_path, server_args.tokenizer_path
)
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes
scheduler_procs = []
......@@ -454,7 +484,8 @@ def launch_engine(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
)
proc.start()
with memory_saver_adapter.configure_subprocess():
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
......@@ -471,7 +502,8 @@ def launch_engine(
target=run_data_parallel_controller_process,
args=(server_args, port_args, writer),
)
proc.start()
with memory_saver_adapter.configure_subprocess():
proc.start()
# Launch detokenizer process
detoken_proc = mp.Process(
......@@ -897,6 +929,18 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
def release_memory_occupation(self):
"""Release GPU occupation temporarily"""
obj = ReleaseMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
def resume_memory_occupation(self):
"""Resume GPU occupation"""
obj = ResumeMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
class Runtime:
"""
......
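For the standalone HTTP server, the same operations are exposed as /release_memory_occupation and /resume_memory_occupation (GET or POST). A small client-side sketch, untested and with a placeholder base URL; the empty JSON body matches the field-less request dataclasses above:

import requests

base_url = "http://localhost:30000"  # wherever the SGLang server was launched

requests.post(f"{base_url}/release_memory_occupation", json={})
# ... the GPU is now largely free for other workloads ...
requests.post(f"{base_url}/resume_memory_occupation", json={})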
......@@ -23,7 +23,6 @@ from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
......@@ -157,6 +156,7 @@ class ServerArgs:
triton_attention_num_kv_splits: int = 8
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False
def __post_init__(self):
# Set missing default values
......@@ -854,6 +854,11 @@ class ServerArgs:
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
parser.add_argument(
"--enable-memory-saver",
action="store_true",
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
from abc import ABC
from contextlib import contextmanager
try:
import torch_memory_saver
_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
except ImportError:
pass
class TorchMemorySaverAdapter(ABC):
@staticmethod
def create(enable: bool):
return (
_TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
)
def configure_subprocess(self):
raise NotImplementedError
def region(self):
raise NotImplementedError
def pause(self):
raise NotImplementedError
def resume(self):
raise NotImplementedError
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
def configure_subprocess(self):
return torch_memory_saver.configure_subprocess()
def region(self):
return _primary_memory_saver.region()
def pause(self):
return _primary_memory_saver.pause()
def resume(self):
return _primary_memory_saver.resume()
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@contextmanager
def configure_subprocess(self):
yield
@contextmanager
def region(self):
yield
def pause(self):
pass
def resume(self):
pass
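One design note on the adapter: create(enable=False) returns a no-op implementation whose region() and configure_subprocess() are empty context managers, so call sites never need to branch on whether memory saving is requested or whether the optional torch_memory_saver package is installed. A quick sketch that runs without the optional dependency:

import torch

from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=False)  # picks the no-op variant

with adapter.region():   # yields immediately and tracks nothing
    x = torch.zeros(16)  # plain allocation, CPU is fine here

adapter.pause()   # both calls are no-ops when memory saving is disabled
adapter.resume()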
......@@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh"
pip install --upgrade pip
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
# Force reinstall flashinfer
# Force reinstall flashinfer and torch_memory_saver
pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
pip install torch_memory_saver --force-reinstall
pip install transformers==4.45.2 sentence_transformers accelerate peft
......
......@@ -29,6 +29,7 @@ suites = {
"test_openai_server.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_release_memory_occupation.py",
"test_retract_decode.py",
"test_server_args.py",
"test_session_control.py",
......
import time
import unittest
import torch
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = True
class TestReleaseMemoryOccupation(unittest.TestCase):
def test_release_and_resume_occupation(self):
prompt = "Today is a sunny day and I like"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
expect_output = " to spend it outdoors. I decided to"
engine = sgl.Engine(
model_path=model_name,
random_seed=42,
enable_memory_saver=True,
# disable_cuda_graph=True, # for debugging only
)
hf_model_new = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="bfloat16"
)
print("generate (#1)")
outputs = engine.generate(prompt, sampling_params)["text"]
self.assertEqual(outputs, expect_output)
if _DEBUG_EXTRA:
time.sleep(3)
self.assertEqual(
_try_allocate_big_tensor(),
False,
"Should not be able to allocate big tensors before releasing",
)
print("release_memory_occupation start")
t = time.time()
engine.release_memory_occupation()
if _DEBUG_EXTRA:
print("release_memory_occupation", time.time() - t)
if _DEBUG_EXTRA:
time.sleep(5)
self.assertEqual(
_try_allocate_big_tensor(),
True,
"Should be able to allocate big tensors aftre releasing",
)
if _DEBUG_EXTRA:
time.sleep(5)
print("resume_memory_occupation start")
t = time.time()
engine.resume_memory_occupation()
if _DEBUG_EXTRA:
print("resume_memory_occupation", time.time() - t)
self.assertEqual(
_try_allocate_big_tensor(),
False,
"Should not be able to allocate big tensors after resuming",
)
print("update_weights_from_tensor")
# As if: PPO has updated hf model's weights, and now we sync it to SGLang
engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))
print("generate (#2)")
outputs = engine.generate(prompt, sampling_params)["text"]
self.assertEqual(outputs, expect_output)
if _DEBUG_EXTRA:
time.sleep(4)
engine.shutdown()
def _try_allocate_big_tensor(size: int = 20_000_000_000):
try:
torch.empty((size,), dtype=torch.uint8, device="cuda")
torch.cuda.empty_cache()
return True
except torch.cuda.OutOfMemoryError:
return False
if __name__ == "__main__":
unittest.main()