Unverified Commit 691c8534 authored by fzyzcjy, committed by GitHub

Support releasing CUDA graph memory when paused (#7873)


Co-authored-by: ryang-max <y1cunhui.yang@gmail.com>
Co-authored-by: ryang <38470282+ryang-max@users.noreply.github.com>
parent d2b8c412
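In short, this commit lets torch_memory_saver also manage the memory that CUDA graphs allocate, behind a new opt-in environment variable, so the engine can release it while paused. A minimal usage sketch based on the test at the bottom of this diff (the model path is a placeholder):

```python
import os

import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH

os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"  # opt in to the new behavior
engine = sgl.Engine(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    enable_memory_saver=True,  # the memory saver itself must also be enabled
)

# CUDA graph memory can now be released and resumed like weights / KV cache;
# calling release_memory_occupation() with no tags covers all three types.
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
```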
 # GPU Memory Types
 GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
 GPU_MEMORY_TYPE_WEIGHTS = "weights"
+GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph"
+GPU_MEMORY_ALL_TYPES = [
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+]
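The tags are plain strings; a quick, purely illustrative sanity check of how the new constant composes with the aggregate list:

```python
from sglang.srt.constants import GPU_MEMORY_ALL_TYPES, GPU_MEMORY_TYPE_CUDA_GRAPH

# The new tag is just another string key that the release/resume APIs accept,
# and the aggregate list is what an empty tag list expands to (see the
# scheduler hunk below).
assert GPU_MEMORY_TYPE_CUDA_GRAPH == "cuda_graph"
assert GPU_MEMORY_ALL_TYPES[-1] == GPU_MEMORY_TYPE_CUDA_GRAPH
```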
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
+from sglang.srt.environ import envs
 from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0

 logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
             self.register_buffer(self.buffer)

         self.disabled = False
+        self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()

     @staticmethod
     def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
             if _is_hip:
                 return self.all_reduce_reg(input)
             else:
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input, registered=not self.tms_cudagraph)
         else:
             # If warm up, mimic the allocation pattern since custom
             # allreduce is out-of-place.
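Why `registered=not self.tms_cudagraph`: our reading is that when torch_memory_saver owns CUDA-graph memory, a pointer registered at capture time may be unmapped while the graph is paused, so the unregistered (copy-into-shared-buffer) path is used instead. A self-contained toy of just that dispatch (names are illustrative, not the real class):

```python
# Toy mirror of the branch above; the real CustomAllreduce needs a CUDA runtime.
class ToyAllreduce:
    def __init__(self, tms_cudagraph: bool):
        # Mirrors: self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()
        self.tms_cudagraph = tms_cudagraph

    def capture_path(self) -> str:
        # registered=True reuses an IPC-registered input pointer; with the
        # memory saver managing CUDA-graph memory, that pointer may not stay
        # valid across pause/resume, so fall back to the unregistered variant.
        return "registered" if not self.tms_cudagraph else "unregistered"

assert ToyAllreduce(tms_cudagraph=True).capture_path() == "unregistered"
assert ToyAllreduce(tms_cudagraph=False).capture_path() == "registered"
```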
@@ -239,6 +239,9 @@ class Envs:
     SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
     SGLANG_RESIZE_RESAMPLE = EnvStr("")

+    # Release & Resume Memory
+    SGLANG_MEMORY_SAVER_CUDA_GRAPH = EnvBool(False)
+
     # Ktransformers
     SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
     SGLANG_KT_MOE_CPUINFER = EnvInt(None)
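`EnvBool` gives the flag a typed accessor; the CustomAllreduce hunk above reads it like this (a sketch):

```python
from sglang.srt.environ import envs

# Typed accessor for the new flag; EnvBool(False) means it is off unless the
# SGLANG_MEMORY_SAVER_CUDA_GRAPH environment variable is set truthy.
if envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get():
    print("CUDA graph memory will be managed by torch_memory_saver")
```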
@@ -5,7 +5,12 @@ from typing import TYPE_CHECKING, Tuple
 import torch

-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.constants import (
+    GPU_MEMORY_ALL_TYPES,
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+)
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     DestroyWeightsUpdateGroupReqOutput,
@@ -104,7 +109,7 @@ class SchedulerUpdateWeightsMixin:
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+            tags = GPU_MEMORY_ALL_TYPES

         for tag in tags:
             self.offload_tags.add(tag)
@@ -120,6 +125,9 @@ class SchedulerUpdateWeightsMixin:
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

+        if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH)
+
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(
@@ -128,11 +136,14 @@ class SchedulerUpdateWeightsMixin:
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+            tags = GPU_MEMORY_ALL_TYPES

         for tag in tags:
             self.offload_tags.remove(tag)

+        if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
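The behavioral change in the scheduler is the default-tag expansion: a release/resume request without tags now covers CUDA graph memory too. A sketch of just that rule (the assertion values come from the constants above):

```python
from sglang.srt.constants import GPU_MEMORY_ALL_TYPES

def effective_tags(requested):
    # Mirrors the `if tags is None or len(tags) == 0` branches above.
    return GPU_MEMORY_ALL_TYPES if not requested else requested

assert effective_tags(None) == ["kv_cache", "weights", "cuda_graph"]
assert effective_tags(["weights"]) == ["weights"]
```

Note also that resume handles the CUDA graph tag before weights; the test at the end of this diff checks that ordering through memory measurements.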
@@ -21,12 +21,14 @@ import inspect
 import logging
 import os
 from contextlib import contextmanager
+from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch
 import tqdm
 from torch.profiler import ProfilerActivity, profile

+from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
     require_mlp_tp_gather,
 )
 from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

 try:
     from kt_kernel import AMXMoEWrapper
@@ -518,7 +521,16 @@ class CudaGraphRunner:
         logger.info(log_message)

     def _capture_graph(self, graph, pool, stream, run_once_fn):
-        with self.device_module.graph(graph, pool=pool, stream=stream):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.model_runner.server_args.enable_memory_saver
+            and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
+        )
+        graph_fn = (
+            partial(memory_saver_adapter.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
+            if memory_saver_adapter.enabled
+            else self.device_module.graph
+        )
+        with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
             out = run_once_fn()
         return out
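What the wrapped capture resolves to, in isolation (a sketch that needs a CUDA device; `torch.cuda.graph`'s first parameter really is named `cuda_graph`, which is why the keyword-style call works on both branches):

```python
import torch

from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

# enable=True requires torch_memory_saver to be installed; False hands back
# the no-op adapter, in which case the plain torch.cuda.graph path is taken.
adapter = TorchMemorySaverAdapter.create(enable=False)

graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream()
if adapter.enabled:
    # torch_memory_saver tags the graph's device allocations so they can be
    # paused/resumed later under the "cuda_graph" tag.
    capture = adapter.cuda_graph(
        cuda_graph=graph, pool=None, stream=stream, tag="cuda_graph"
    )
else:
    capture = torch.cuda.graph(cuda_graph=graph, pool=None, stream=stream)
with capture:
    pass  # the model's run-once function executes here during real capture
```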
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError

+    def cuda_graph(self, **kwargs):
+        raise NotImplementedError
+
+    def disable(self):
+        raise NotImplementedError
+
     def pause(self, tag: str):
         raise NotImplementedError
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)

+    def cuda_graph(self, **kwargs):
+        return _memory_saver.cuda_graph(**kwargs)
+
+    def disable(self):
+        return _memory_saver.disable()
+
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         yield

+    @contextmanager
+    def cuda_graph(self, **kwargs):
+        yield
+
+    @contextmanager
+    def disable(self):
+        yield
+
     def pause(self, tag: str):
         pass
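The Noop variant keeps call sites branch-free: when the memory saver is disabled (or torch_memory_saver is not installed), the same calls are simply inert. For instance (a sketch):

```python
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=False)  # Noop variant

# region()/cuda_graph()/disable() are empty context managers, and
# pause()/resume() silently do nothing, so schedulers can call them
# unconditionally.
with adapter.region(tag="weights"):
    pass  # allocations here are untracked under the Noop adapter
adapter.pause("weights")
adapter.resume("weights")
```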
@@ -25,6 +25,7 @@ configurations (tp=1, tp=2) to ensure proper memory management in distributed se
 data parallel size, we test it in verl.
 """

+import os
 import time
 import unittest

@@ -32,7 +33,11 @@ import torch
 from transformers import AutoModelForCausalLM

 import sglang as sgl
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.constants import (
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+)
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
@@ -59,6 +64,8 @@ class TestReleaseMemoryOccupation(CustomTestCase):
         enable_weights_cpu_backup=False,
     ):
         """Common setup for engine and HF model."""
+        os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"
+
         engine = sgl.Engine(
             model_path=model_name,
             random_seed=42,
@@ -215,6 +222,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
                 continue

             print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
+            os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"
             engine = sgl.Engine(
                 model_path=model_name,
                 random_seed=42,
@@ -232,17 +240,17 @@ class TestReleaseMemoryOccupation(CustomTestCase):
             )

             t = time.perf_counter()
-            gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
+            gpu_memory_usage_before_release = get_gpu_memory_gb()
             engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
             gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()

             self.assertLess(
                 gpu_memory_usage_after_release_kv_cache,
-                gpu_memory_usage_before_release_kv_cache,
+                gpu_memory_usage_before_release,
             )

             engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
             gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
             self.assertLess(
@@ -250,32 +258,48 @@ class TestReleaseMemoryOccupation(CustomTestCase):
                 gpu_memory_usage_after_release_kv_cache,
             )

+            engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
+            gpu_memory_usage_after_release_cuda_graph = get_gpu_memory_gb()
+            self.assertLess(
+                gpu_memory_usage_after_release_cuda_graph,
+                gpu_memory_usage_after_release_weights,
+            )
+
             print(f"Release took {time.perf_counter() - t:.2f}s")
             print(
-                f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} GB"
+                f"Memory: {gpu_memory_usage_before_release:.1f} → {gpu_memory_usage_after_release_kv_cache:.1f} → {gpu_memory_usage_after_release_weights:.1f} → {gpu_memory_usage_after_release_cuda_graph:.1f} GB"
             )

             if _DEBUG_EXTRA:
                 time.sleep(3)

             t = time.perf_counter()
-            gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
+            gpu_memory_usage_before_resume = get_gpu_memory_gb()

-            # gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
+            # gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume should be close
             self.assertAlmostEqual(
                 gpu_memory_usage_after_release_weights,
-                gpu_memory_usage_before_resume_weights,
+                gpu_memory_usage_before_resume,
                 delta=3.0,
             )
             print(f"Resume weights took {time.perf_counter() - t:.2f}s")

+            engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
+            gpu_memory_usage_after_resume_cuda_graph = get_gpu_memory_gb()
+            self.assertGreater(
+                gpu_memory_usage_after_resume_cuda_graph,
+                gpu_memory_usage_before_resume,
+            )
+
             engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
             gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
             self.assertGreater(
                 gpu_memory_usage_after_resume_weights,
-                gpu_memory_usage_before_resume_weights,
+                gpu_memory_usage_after_resume_cuda_graph,
             )

             # Update weights from a trained model to serving engine, and then destroy the trained model
@@ -300,7 +324,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
             print(f"Resume + update took {time.perf_counter() - t:.2f}s")
             print(
-                f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB"
+                f"Memory: {gpu_memory_usage_before_resume:.1f} → {gpu_memory_usage_after_resume_cuda_graph:.1f} → {gpu_memory_usage_after_resume_weights:.1f} → {gpu_memory_usage_after_loaded_hf_model:.1f} → {gpu_memory_usage_after_resume_kv_cache:.1f} GB"
             )

             print("generate (#2)")
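Condensing the assertions above, the memory trajectory the test expects looks roughly like this (values are illustrative, not from a real run):

```python
# Release: each stage must strictly lower the reported GPU memory.
before, after_kv, after_weights, after_graph = 20.0, 12.0, 4.0, 3.0
assert after_kv < before
assert after_weights < after_kv
assert after_graph < after_weights

# Resume: CUDA graph first, then weights; each stage must strictly raise it.
after_resume_graph, after_resume_weights = 4.0, 12.0
assert after_resume_graph > after_graph
assert after_resume_weights > after_resume_graph
```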