Unverified Commit 691c8534 authored by fzyzcjy, committed by GitHub

Support releasing CUDA graph memory when paused (#7873)


Co-authored-by: ryang-max <y1cunhui.yang@gmail.com>
Co-authored-by: ryang <38470282+ryang-max@users.noreply.github.com>
parent d2b8c412
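For orientation, here is a minimal end-to-end sketch of what this commit enables, condensed from the test added below; the engine arguments (`enable_memory_saver`, the small test model) follow that test and are not part of this commit's own description:

```python
import os

import sglang as sgl
from sglang.srt.constants import (
    GPU_MEMORY_TYPE_CUDA_GRAPH,
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

# Opt in to memory-saver-managed CUDA graph capture (env var added by this commit).
os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"

engine = sgl.Engine(
    model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    random_seed=42,
    enable_memory_saver=True,  # required for any release/resume tag to take effect
)

# Release in the same order as the test: KV cache, weights, then the CUDA graph pool.
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])

# ... the freed GPU memory can be used elsewhere (e.g. a training step) ...

# Resume in reverse: CUDA graph pool, weights, then KV cache.
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])

engine.shutdown()
```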
# GPU Memory Types
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph"
GPU_MEMORY_ALL_TYPES = [
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
GPU_MEMORY_TYPE_CUDA_GRAPH,
]
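A quick sanity check of the new aggregate constant; omitting `tags` in the release/resume paths further down now covers the CUDA graph pool as well:

```python
from sglang.srt.constants import (
    GPU_MEMORY_ALL_TYPES,
    GPU_MEMORY_TYPE_CUDA_GRAPH,
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_WEIGHTS,
)

# The tags are plain strings; the scheduler expands "no tags given" to this list.
assert GPU_MEMORY_ALL_TYPES == [
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_WEIGHTS,
    GPU_MEMORY_TYPE_CUDA_GRAPH,
]
```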
@@ -18,6 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.environ import envs
from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ class CustomAllreduce:
self.register_buffer(self.buffer)
self.disabled = False
self.tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()
@staticmethod
def create_shared_buffer(
@@ -394,7 +396,7 @@ class CustomAllreduce:
if _is_hip:
return self.all_reduce_reg(input)
else:
return self.all_reduce(input, registered=True)
return self.all_reduce(input, registered=not self.tms_cudagraph)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
......
@@ -239,6 +239,9 @@ class Envs:
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("")
# Release & Resume Memory
SGLANG_MEMORY_SAVER_CUDA_GRAPH = EnvBool(False)
# Ktransformers
SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
SGLANG_KT_MOE_CPUINFER = EnvInt(None)
......
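The new flag is read in two ways within this commit: through the typed `envs` accessor and through a plain environment lookup. A small sketch of both (the assertion assumes the two accessors read the environment at call time and agree that `"1"` is truthy, which matches what the test sets):

```python
import os

from sglang.srt.environ import envs
from sglang.srt.utils import get_bool_env_var

# Opt in (the test in this commit sets the same value).
os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"

# Typed accessor, as used in CustomAllreduce above.
use_saver_for_graphs = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get()

# Plain env-var lookup, as used in CudaGraphRunner._capture_graph below.
# Assumption: both accessors treat "1" as truthy, so they agree here.
assert use_saver_for_graphs == get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
```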
@@ -5,7 +5,12 @@ from typing import TYPE_CHECKING, Tuple
import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constants import (
GPU_MEMORY_ALL_TYPES,
GPU_MEMORY_TYPE_CUDA_GRAPH,
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.srt.managers.io_struct import (
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
@@ -104,7 +109,7 @@ class SchedulerUpdateWeightsMixin:
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
tags = GPU_MEMORY_ALL_TYPES
for tag in tags:
self.offload_tags.add(tag)
@@ -120,6 +125,9 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(self.tp_cpu_group)
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH)
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(
@@ -128,11 +136,14 @@ class SchedulerUpdateWeightsMixin:
tags = recv_req.tags
if tags is None or len(tags) == 0:
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
tags = GPU_MEMORY_ALL_TYPES
for tag in tags:
self.offload_tags.remove(tag)
if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH)
if GPU_MEMORY_TYPE_WEIGHTS in tags:
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
torch.distributed.barrier(self.tp_cpu_group)
......
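Condensing the adapter calls from the two handlers above into a standalone sketch (the `adapter` argument stands in for `self.memory_saver_adapter`; barriers, `offload_tags` bookkeeping, and KV-cache handling are omitted):

```python
from sglang.srt.constants import (
    GPU_MEMORY_TYPE_CUDA_GRAPH,
    GPU_MEMORY_TYPE_WEIGHTS,
)

def release(adapter, tags):
    # Condensed from release_memory_occupation above: pause each requested
    # region so the memory saver can hand its device memory back.
    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
    if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
        adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH)

def resume(adapter, tags):
    # Condensed from resume_memory_occupation above: the CUDA graph pool is
    # restored before the weights.
    if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
        adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH)
    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
```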
@@ -21,12 +21,14 @@ import inspect
import logging
import os
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from torch.profiler import ProfilerActivity, profile
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
@@ -64,6 +66,7 @@ from sglang.srt.utils import (
require_mlp_tp_gather,
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
try:
from kt_kernel import AMXMoEWrapper
@@ -518,7 +521,16 @@ class CudaGraphRunner:
logger.info(log_message)
def _capture_graph(self, graph, pool, stream, run_once_fn):
with self.device_module.graph(graph, pool=pool, stream=stream):
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=self.model_runner.server_args.enable_memory_saver
and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
)
graph_fn = (
partial(memory_saver_adapter.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
if memory_saver_adapter.enabled
else self.device_module.graph
)
with graph_fn(cuda_graph=graph, pool=pool, stream=stream):
out = run_once_fn()
return out
......
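The capture-context dispatch in isolation, as a hedged sketch (`saver` and `run_once` are placeholder names; the real logic is `CudaGraphRunner._capture_graph` above):

```python
from functools import partial

import torch

from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH

def capture(saver, graph: torch.cuda.CUDAGraph, pool, stream, run_once):
    # With the memory saver active, capture happens inside its cuda_graph()
    # context so the graph's pool is tagged and can later be paused/resumed;
    # otherwise fall back to the stock torch.cuda.graph context manager.
    graph_ctx = (
        partial(saver.cuda_graph, tag=GPU_MEMORY_TYPE_CUDA_GRAPH)
        if saver.enabled
        else torch.cuda.graph
    )
    with graph_ctx(cuda_graph=graph, pool=pool, stream=stream):
        return run_once()
```

The explicit `else torch.cuda.graph` branch matters: the no-op adapter's `cuda_graph()` below only yields and would not start a capture on its own.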
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
def region(self, tag: str, enable_cpu_backup: bool = False):
raise NotImplementedError
def cuda_graph(self, **kwargs):
raise NotImplementedError
def disable(self):
raise NotImplementedError
def pause(self, tag: str):
raise NotImplementedError
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
def region(self, tag: str, enable_cpu_backup: bool = False):
return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
def cuda_graph(self, **kwargs):
return _memory_saver.cuda_graph(**kwargs)
def disable(self):
return _memory_saver.disable()
def pause(self, tag: str):
return _memory_saver.pause(tag=tag)
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
def region(self, tag: str, enable_cpu_backup: bool = False):
yield
@contextmanager
def cuda_graph(self, **kwargs):
yield
@contextmanager
def disable(self):
yield
def pause(self, tag: str):
pass
......
@@ -25,6 +25,7 @@ configurations (tp=1, tp=2) to ensure proper memory management in distributed se
data parallel size, we test it in verl.
"""
import os
import time
import unittest
@@ -32,7 +33,11 @@ import torch
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constants import (
GPU_MEMORY_TYPE_CUDA_GRAPH,
GPU_MEMORY_TYPE_KV_CACHE,
GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
@@ -59,6 +64,8 @@ class TestReleaseMemoryOccupation(CustomTestCase):
enable_weights_cpu_backup=False,
):
"""Common setup for engine and HF model."""
os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"
engine = sgl.Engine(
model_path=model_name,
random_seed=42,
@@ -215,6 +222,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
continue
print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "1"
engine = sgl.Engine(
model_path=model_name,
random_seed=42,
@@ -232,17 +240,17 @@
)
t = time.perf_counter()
gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
gpu_memory_usage_before_release = get_gpu_memory_gb()
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()
self.assertLess(
gpu_memory_usage_after_release_kv_cache,
gpu_memory_usage_before_release_kv_cache,
gpu_memory_usage_before_release,
)
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
self.assertLess(
@@ -250,32 +258,48 @@
gpu_memory_usage_after_release_kv_cache,
)
engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
gpu_memory_usage_after_release_cuda_graph = get_gpu_memory_gb()
self.assertLess(
gpu_memory_usage_after_release_cuda_graph,
gpu_memory_usage_after_release_weights,
)
print(f"Release took {time.perf_counter() - t:.2f}s")
print(
f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f}{gpu_memory_usage_after_release_kv_cache:.1f}{gpu_memory_usage_after_release_weights:.1f} GB"
f"Memory: {gpu_memory_usage_before_release:.1f}{gpu_memory_usage_after_release_kv_cache:.1f}{gpu_memory_usage_after_release_weights:.1f} {gpu_memory_usage_after_release_cuda_graph:.1f} GB"
)
if _DEBUG_EXTRA:
time.sleep(3)
t = time.perf_counter()
gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
gpu_memory_usage_before_resume = get_gpu_memory_gb()
# gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
# gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume should be close
self.assertAlmostEqual(
gpu_memory_usage_after_release_weights,
gpu_memory_usage_before_resume_weights,
gpu_memory_usage_before_resume,
delta=3.0,
)
print(f"Resume weights took {time.perf_counter() - t:.2f}s")
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
gpu_memory_usage_after_resume_cuda_graph = get_gpu_memory_gb()
self.assertGreater(
gpu_memory_usage_after_resume_cuda_graph,
gpu_memory_usage_before_resume,
)
engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
self.assertGreater(
gpu_memory_usage_after_resume_weights,
gpu_memory_usage_before_resume_weights,
gpu_memory_usage_after_resume_cuda_graph,
)
# Update weights from a trained model to serving engine, and then destroy the trained model
@@ -300,7 +324,7 @@ class TestReleaseMemoryOccupation(CustomTestCase):
print(f"Resume + update took {time.perf_counter() - t:.2f}s")
print(
f"Memory: {gpu_memory_usage_before_resume_weights:.1f}{gpu_memory_usage_after_resume_weights:.1f}{gpu_memory_usage_after_loaded_hf_model:.1f}{gpu_memory_usage_after_resume_kv_cache:.1f} GB"
f"Memory: {gpu_memory_usage_before_resume:.1f}{gpu_memory_usage_after_resume_cuda_graph:.1f}{gpu_memory_usage_after_resume_weights:.1f}{gpu_memory_usage_after_loaded_hf_model:.1f}{gpu_memory_usage_after_resume_kv_cache:.1f} GB"
)
print("generate (#2)")
......