Unverified commit 8b48496a, authored by Ying Sheng, committed by GitHub

Revert "Revert "Add simple CPU offloading support"" (#2253)


Co-authored-by: Jani Monoses <jani.monoses@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent 4057ea82
......@@ -61,6 +61,7 @@ from sglang.srt.utils import (
is_hip,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
logger = logging.getLogger(__name__)
......@@ -145,6 +146,8 @@ class ModelRunner:
}
)
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
# Init components
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
......
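Editor's note: the call added above converts the new cpu_offload_gb server argument into a global byte budget before any weights are loaded. A minimal standalone sketch of that conversion (the 4 GB figure is only an illustration, not a recommended value):

from sglang.srt.utils import set_cpu_offload_max_bytes

# Sketch only: the same conversion ModelRunner performs above.
cpu_offload_gb = 4  # stands in for server_args.cpu_offload_gb
set_cpu_offload_max_bytes(int(cpu_offload_gb * 1024**3))  # 4 GiB == 4294967296 bytes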
......@@ -38,6 +38,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
# Aligned with HF's implementation, using sliding window inclusive with the last token
......@@ -267,11 +268,15 @@ class Gemma2Model(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
for layer_id in range(config.num_hidden_layers)
]
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Gemma2DecoderLayer(
layer_id=idx,
config=config,
cache_config=cache_config,
quant_config=quant_config,
),
prefix="",
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
class LlamaMLP(nn.Module):
......@@ -255,13 +256,12 @@ class LlamaModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
)
for i in range(config.num_hidden_layers)
]
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: LlamaDecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
prefix="model.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
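Editor's note: the make_layers helper added to sglang.srt.utils later in this diff builds each layer's prefix as f"{prefix}.{idx}", so the factory above receives the same strings the removed list comprehension built by hand. A hedged illustration (the layer count is arbitrary):

# Sketch only: how make_layers invokes the factory for the Llama change above.
num_hidden_layers = 3  # illustrative
prefix = "model.layers"
for idx in range(num_hidden_layers):
    print(f"layer_fn(idx={idx}, prefix='{prefix}.{idx}')")
# -> layer_fn(idx=0, prefix='model.layers.0'), matching the old f"model.layers.{i}" strings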
......@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
class OlmoAttention(nn.Module):
......@@ -220,11 +221,13 @@ class OlmoModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size
)
self.layers = nn.ModuleList(
[
OlmoDecoderLayer(config, layer_id, quant_config)
for layer_id in range(config.num_hidden_layers)
]
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: OlmoDecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
),
)
self.norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=False, bias=False
......
......@@ -48,6 +48,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
class OlmoeMoE(nn.Module):
......@@ -261,11 +262,13 @@ class OlmoeModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
]
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: OlmoeDecoderLayer(
config=config,
quant_config=quant_config,
layer_id=idx,
),
)
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
......
......@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers
Qwen2Config = None
......@@ -230,11 +231,13 @@ class Qwen2Model(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
Qwen2DecoderLayer(config, i, quant_config=quant_config)
for i in range(config.num_hidden_layers)
]
self.layers = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Qwen2DecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -62,6 +62,7 @@ class ServerArgs:
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
# Other runtime options
tp_size: int = 1
......@@ -367,6 +368,13 @@ class ServerArgs:
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
parser.add_argument(
"--cpu-offload-gb",
type=int,
default=ServerArgs.cpu_offload_gb,
help="How many GBs of RAM to reserve for CPU offloading",
)
# Other runtime options
parser.add_argument(
"--tensor-parallel-size",
......
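Editor's note: the new flag can be supplied on the command line (assuming the usual launch_server entry point) or programmatically through ServerArgs / Engine. A hedged sketch, with a placeholder model path:

# Equivalent CLI (placeholder model path):
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --cpu-offload-gb 4
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    cpu_offload_gb=4,
)
print(int(args.cpu_offload_gb * 1024**3))  # byte budget later handed to set_cpu_offload_max_bytes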
......@@ -32,7 +32,7 @@ import time
import warnings
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import numpy as np
import psutil
......@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
......@@ -192,6 +193,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
return free_gpu_memory / (1 << 30)
def is_pin_memory_available() -> bool:
return torch.cuda.is_available()
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = max_bytes
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device = next(module.parameters()).device
if device == torch.device("cpu"):
return module
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
return module
pin_memory = is_pin_memory_available()
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters = False
for p in module.parameters():
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
p.data = cpu_data
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True
if offloaded_parameters:
original_forward = module.forward
def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module, device_state, args=args, kwargs=kwargs)
module.forward = forward
return output
module.forward = forward
return module
class LayerFn(Protocol):
def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str = "",
) -> torch.nn.ModuleList:
"""Make a list of layers with the given layer function"""
modules = torch.nn.ModuleList(
[
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
for idx in range(num_hidden_layers)
]
)
return modules
def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed)
......
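Editor's note: a quick illustration of the new helpers in isolation (a sketch only, assuming a CUDA device is present; the Linear layer and the 1 GiB budget are arbitrary). Once a parameter is offloaded it lives in pinned host memory and is copied back to the GPU transparently for the duration of each forward call:

import torch
from torch import nn
from sglang.srt.utils import maybe_offload_to_cpu, set_cpu_offload_max_bytes

set_cpu_offload_max_bytes(1 * 1024**3)       # allow up to 1 GiB of weights on the CPU
layer = nn.Linear(4096, 4096, device="cuda")
layer = maybe_offload_to_cpu(layer)

print(layer.weight.device)                   # cpu: the weight now lives in pinned host memory
x = torch.randn(2, 4096, device="cuda")
y = layer(x)                                 # weights are copied to the GPU only for this call
print(y.device)                              # cuda:0

Because the budget is tracked per parameter, a single module may end up with some weights on the CPU and some on the GPU once the limit is reached; the pinned host copies keep the per-forward host-to-device transfers cheap, which the comments above note also helps CUDA graph capture.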
......@@ -152,7 +152,37 @@ class TestSRTEngine(unittest.TestCase):
self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
def test_7_engine_offline_throughput(self):
def test_7_engine_cpu_offload(self):
prompt = "Today is a sunny day and I like"
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
sampling_params = {"temperature": 0, "max_new_tokens": 8}
engine = sgl.Engine(
model_path=model_path,
random_seed=42,
max_total_tokens=128,
)
out1 = engine.generate(prompt, sampling_params)["text"]
engine.shutdown()
engine = sgl.Engine(
model_path=model_path,
random_seed=42,
max_total_tokens=128,
cpu_offload_gb=3,
)
out2 = engine.generate(prompt, sampling_params)["text"]
engine.shutdown()
print("==== Answer 1 ====")
print(out1)
print("==== Answer 2 ====")
print(out2)
self.assertEqual(out1, out2)
def test_8_engine_offline_throughput(self):
server_args = ServerArgs(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)
......
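Editor's note: the test only checks that offloading is output-preserving. If one also wanted to observe the memory effect, a rough check like the following could be used (a sketch only, not part of the committed test):

import torch

# Rough illustration: with cpu_offload_gb > 0, the free GPU memory reported
# after engine start-up should be noticeably higher for the same model.
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"{free_bytes / 1024**3:.2f} GiB free of {total_bytes / 1024**3:.2f} GiB")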