Unverified Commit 031f64aa authored by Ata Fatahi's avatar Ata Fatahi Committed by GitHub
Browse files

Add e2e test for multi instance multi stage memory release/resume occupuation (#7208)


Signed-off-by: default avatarAta Fatahi <immrata@gmail.com>
parent 3d7cdb2e
......@@ -160,6 +160,7 @@ suites = {
"per-commit-4-gpu": [
TestFile("test_local_attn.py", 250),
TestFile("test_pp_single_node.py", 150),
TestFile("test_multi_instance_release_memory_occupation.py", 64),
],
"per-commit-4-gpu-amd": [
TestFile("test_pp_single_node.py", 150),
......
import multiprocessing
import os
import subprocess
import traceback
import unittest
from multiprocessing import Process
from typing import Iterable, Tuple
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.engine import Engine as SglangEngine
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
CustomTestCase,
find_available_port,
)
TEST_SUITE = dict(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
mem_fraction_static=0.85,
dp_size=2,
tp_size=2,
)
class EngineWrapper:
"""
A wrapper around Sglang engine to mock multi instance cases such as RL traing.
"""
def __init__(
self, model_path, random_seed, mem_fraction_static, device_mesh_cpu, base_gpu_id
):
self._device_mesh_cpu = device_mesh_cpu
self._tp_rank = device_mesh_cpu.get_local_rank()
self._rank = device_mesh_cpu.get_rank()
self._tp_size = device_mesh_cpu.size()
tp_size_per_node = self._tp_size
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
engine_kwargs = dict(
model_path=model_path,
random_seed=random_seed,
mem_fraction_static=mem_fraction_static,
base_gpu_id=base_gpu_id,
enable_memory_saver=True,
tp_size=self._tp_size,
node_rank=node_rank,
nnodes=1,
)
self._engine = None
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = SglangEngine(**engine_kwargs)
dist.barrier(group=self._device_mesh_cpu.get_group())
def update_weights_from_tensor(
self, named_tensors: Iterable[Tuple[str, torch.Tensor]]
):
if self._tp_rank == 0:
self._engine.update_weights_from_tensor(list(named_tensors))
self._engine.flush_cache()
dist.barrier(group=self._device_mesh_cpu.get_group())
def release_memory_occupation(self, tags):
if self._tp_rank == 0:
self._engine.release_memory_occupation(tags)
dist.barrier(group=self._device_mesh_cpu.get_group())
def resume_memory_occupation(self, tags):
if self._tp_rank == 0:
self._engine.resume_memory_occupation(tags)
dist.barrier(group=self._device_mesh_cpu.get_group())
def shutdown(self):
if self._tp_rank == 0:
self._engine.shutdown()
dist.barrier(group=self._device_mesh_cpu.get_group())
def get_gpu_memory_gb(gpu_id=0):
return torch.cuda.device_memory_used() / 1024**3
class TestMultiInstanceReleaseMemoryOccupation(CustomTestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
def test_multi_instance_release_memory_occupation(self):
master_port = find_available_port(23456)
dp_size = TEST_SUITE["dp_size"]
tp_size = TEST_SUITE["tp_size"]
world_size = dp_size * tp_size
processes = []
output_reader, output_writer = multiprocessing.Pipe(duplex=False)
for rank in range(world_size):
p = Process(
target=_run_sglang_subprocess,
kwargs=dict(
rank=rank,
dp_size=dp_size,
tp_size=tp_size,
model_path=TEST_SUITE["model_path"],
master_port=master_port,
output_writer=output_writer,
mem_fraction_static=TEST_SUITE["mem_fraction_static"],
),
)
p.start()
processes.append(p)
for _ in range(world_size):
self.assertTrue(
output_reader.recv(), f"Subprocess fail. Check the logs above."
)
for p in processes:
p.join()
def _run_sglang_subprocess(
rank: int,
dp_size: int,
tp_size: int,
model_path: str,
master_port: int,
output_writer,
mem_fraction_static: float,
):
engine = None
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
dist.init_process_group(
rank=rank,
device_id=torch.device(f"cuda:{rank}"),
world_size=dp_size * tp_size,
)
torch.cuda.set_device(rank)
base_gpu_id = rank // tp_size * tp_size
mesh_kwargs = dict(
mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
)
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
print(
f"subprocess[{rank=},{base_gpu_id=},{rank=},{tp_size=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
)
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage before starting Engine: {_mem_usage}")
engine = EngineWrapper(
model_path=model_path,
random_seed=42,
mem_fraction_static=mem_fraction_static,
device_mesh_cpu=inference_device_mesh_cpu["tp"],
base_gpu_id=base_gpu_id,
)
print(f"subprocess[{rank=}] {engine=}", flush=True)
# 1 - release kv cache
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage before releasing Sgl KV cache: {_mem_usage}")
engine.release_memory_occupation(tags=["kv_cache"])
_curr_usage = get_gpu_memory_gb(rank)
assert (
_curr_usage < _mem_usage
), f"Memory usage after releasing KV cache must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
# 2 - release sglang weights
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage before releasing Sgl weights: {_mem_usage}")
engine.release_memory_occupation(tags=["weights"])
_curr_usage = get_gpu_memory_gb(rank)
assert (
_curr_usage < _mem_usage
), f"Memory usage after releasing weights must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
# 3 - load hf model
_mem_usage = get_gpu_memory_gb(rank)
print(
f"GPU{rank} Memory usage after releasing Sgl weights and kv cache: {_mem_usage}"
)
hf_model = AutoModelForCausalLM.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
torch_dtype="bfloat16",
device_map=f"cuda:{rank}",
trust_remote_code=True,
).cuda()
_curr_usage = get_gpu_memory_gb(rank)
assert (
_curr_usage > _mem_usage
), f"Memory usage after loading hf model must be increased! before: {_mem_usage} vs after: {_curr_usage}"
# 4 - resume sglang weights and update the weights
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage after loading hf model: {_mem_usage}")
engine.resume_memory_occupation(tags=["weights"])
engine.update_weights_from_tensor(
named_tensors=list(hf_model.named_parameters())
)
# 5 - release hf model
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}")
del hf_model
torch.cuda.empty_cache()
_curr_usage = get_gpu_memory_gb(rank)
assert (
_curr_usage < _mem_usage
), f"Memory usage after releasing hf model must be reduced! before: {_mem_usage} vs after: {_curr_usage}"
# 6 - resume slgang kv cache
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage after releasing hf model: {_mem_usage}")
engine.resume_memory_occupation(tags=["kv_cache"])
_curr_usage = get_gpu_memory_gb(rank)
assert (
_curr_usage > _mem_usage
), f"Memory usage after resuming kv cache must be increased! before: {_mem_usage} vs after: {_curr_usage}"
# 7 - Final checking!
_mem_usage = get_gpu_memory_gb(rank)
print(f"GPU{rank} Memory usage after resuming Sgl KV cache: {_mem_usage}")
execution_ok = True
except Exception as e:
print(f"subprocess[{rank=}] has error: {e}", flush=True)
traceback.print_exc()
execution_ok = False
output_writer.send(execution_ok)
output_writer.close()
if engine:
engine.shutdown()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment