Unverified commit 2600fc0d authored by fzyzcjy, committed by GitHub

Overlapped weight offload (#8034)

parent ccd3fb94
import base64
import os
import pickle
import time
from pathlib import Path
from typing import Any, List, Optional
import torch
from sglang.srt.utils import MultiprocessingSerializer
class NaiveDistributed:
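    """Minimal file-based collectives (barrier, all_gather_object, scatter) across
    local processes, coordinated through files in the `rendezvous` directory."""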
def __init__(self, rank: int, world_size: int, rendezvous: str):
self._rank = rank
self._world_size = world_size
self._operation_index = 0
self._directory = Path(rendezvous)
self._directory.mkdir(parents=True, exist_ok=True)
assert 0 <= rank < world_size
        # Barrier both for safety and as a sanity check that all ranks are present
self.barrier()
def get_rank(self):
return self._rank
def get_world_size(self):
return self._world_size
def scatter(
self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0
):
if self._rank == src:
assert len(scatter_list) == self._world_size
else:
assert scatter_list is None
gathered_objects = self.all_gather_object(
dict(
serialized_scatter_list=[
(
None
if item_rank == src
else MultiprocessingSerializer.serialize(item)
)
for item_rank, item in enumerate(scatter_list)
]
)
if self._rank == src
else dict()
)
remote_serialized_tensor = gathered_objects[src]["serialized_scatter_list"][
self._rank
]
if self._rank == src:
assert remote_serialized_tensor is None
remote_tensor = scatter_list[self._rank]
else:
remote_tensor = MultiprocessingSerializer.deserialize(
remote_serialized_tensor
)
tensor.copy_(remote_tensor)
        # Barrier so the source tensors are not freed before every rank has copied from them
self.barrier()
def all_gather_object(self, obj: Any) -> List[Any]:
self._operation_index += 1
text_postfix = "\n"
def _get_path(interesting_rank: int):
return (
self._directory
/ f"rank{interesting_rank}_op{self._operation_index}.txt"
)
_get_path(self._rank).write_text(
base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
)
def _read_one(interesting_rank: int):
p = _get_path(interesting_rank)
while True:
if p.exists() and (text := p.read_text()).endswith(text_postfix):
return pickle.loads(base64.b64decode(text[: -len(text_postfix)]))
time.sleep(0.001)
return [
_read_one(interesting_rank) for interesting_rank in range(self._world_size)
]
def barrier(self):
actual_objs = self.all_gather_object(self._rank)
assert actual_objs == list(range(self._world_size)), f"{actual_objs=}"
# Multiple instances can be created if needed
_instance: Optional[NaiveDistributed] = None
def get_naive_distributed():
assert _instance is not None
return _instance
def set_naive_distributed(instance: NaiveDistributed):
global _instance
assert _instance is None
_instance = instance
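For orientation, here is a minimal usage sketch of the file-based rendezvous above. It is not part of this commit; it assumes sglang is installed so the class is importable, and it uses a throwaway directory as the rendezvous.

```python
# Hypothetical demo: two local processes exchange objects via NaiveDistributed.
import multiprocessing as mp
import tempfile

from sglang.srt.distributed.naive_distributed import NaiveDistributed


def _worker(rank: int, world_size: int, rendezvous: str):
    dist = NaiveDistributed(rank=rank, world_size=world_size, rendezvous=rendezvous)
    # Every rank contributes one object and receives the list from all ranks.
    gathered = dist.all_gather_object({"rank": rank, "payload": rank * 10})
    assert [item["rank"] for item in gathered] == list(range(world_size))
    dist.barrier()


if __name__ == "__main__":
    world_size = 2
    rendezvous = tempfile.mkdtemp()
    procs = [
        mp.Process(target=_worker, args=(rank, world_size, rendezvous))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```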
@@ -23,8 +23,10 @@ import dataclasses
import logging
import multiprocessing as mp
import os
import random
import signal
import threading
import time
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
@@ -654,6 +656,11 @@ def _set_envs_and_config(server_args: ServerArgs):
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
os.environ["TRTLLM_ENABLE_PDL"] = "1"
    # This could also be passed in as an argument
os.environ["SGLANG_RUN_ID"] = (
f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
)
# Set prometheus env vars
if server_args.enable_metrics:
set_prometheus_multiproc_dir()
......
import logging
import os
from dataclasses import dataclass
from multiprocessing import shared_memory
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
from sglang.srt.distributed.naive_distributed import get_naive_distributed
from sglang.srt.utils import check_cuda_result
logger = logging.getLogger(__name__)
class HostSharedMemoryManager:
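    """Allocates host tensors backed by POSIX shared memory so that all local ranks
    map the same buffer, and pins the memory with cudaHostRegister so host-to-device
    copies can run asynchronously."""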
def __init__(self, base_name: str):
self._base_name = Path(base_name)
self._operation_index = 0
self._records: List[_Record] = []
def malloc(self, *, shape, dtype):
meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta")
raw = self._malloc_raw(num_bytes=meta_tensor.nbytes)
return raw.view(dtype).view(*shape)
def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor:
import cuda.bindings.runtime as cuda_rt
self._operation_index += 1
shm_name = f"{self._base_name}_op{self._operation_index}"
# TODO handle dispose
if get_naive_distributed().get_rank() == 0:
shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes)
get_naive_distributed().barrier()
if get_naive_distributed().get_rank() != 0:
shm = shared_memory.SharedMemory(name=shm_name)
np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf)
tensor = torch.from_numpy(np_array)
check_cuda_result(
cuda_rt.cudaHostRegister(
tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable
)
)
get_naive_distributed().barrier()
self._records.append(
_Record(
shm=shm,
np_array=np_array,
tensor=tensor,
)
)
return tensor
@dataclass
class _Record:
shm: shared_memory.SharedMemory
np_array: np.ndarray
tensor: torch.Tensor
# Multiple instances can be created if needed
_instance: Optional[HostSharedMemoryManager] = None
def get_host_shared_memory_manager():
assert _instance is not None
return _instance
def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
global _instance
assert _instance is None
_instance = instance
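A hedged sketch of how the manager above is meant to be used; it is not part of this commit. It assumes a CUDA-capable machine with cuda-python installed; with world_size=1 only the rank-0 code path is exercised, and cleanup of the shared-memory segment (the TODO above) is omitted.

```python
# Hypothetical single-rank demo of HostSharedMemoryManager.
import torch

from sglang.srt.distributed.naive_distributed import (
    NaiveDistributed,
    set_naive_distributed,
)
from sglang.srt.host_shared_memory import (
    HostSharedMemoryManager,
    get_host_shared_memory_manager,
    set_host_shared_memory_manager,
)

# The rendezvous path and base name below are placeholders.
set_naive_distributed(NaiveDistributed(rank=0, world_size=1, rendezvous="/tmp/demo-run"))
set_host_shared_memory_manager(HostSharedMemoryManager(base_name="demo-run"))

# All ranks would call malloc with the same shape/dtype and get views of the same
# shared-memory block, pinned via cudaHostRegister for asynchronous H2D copies.
shared = get_host_shared_memory_manager().malloc(shape=(1024, 1024), dtype=torch.bfloat16)
shared.zero_()
print(shared.shape, shared.dtype, shared.device)
```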
@@ -92,6 +92,7 @@ class TpModelWorker:
pp_rank=pp_rank,
pp_size=server_args.pp_size,
nccl_port=nccl_port,
dp_rank=dp_rank,
server_args=server_args,
is_draft_worker=is_draft_worker,
req_to_token_pool=req_to_token_pool,
......
@@ -172,6 +172,7 @@ class ModelRunner:
pp_size: int,
nccl_port: int,
server_args: ServerArgs,
dp_rank: Optional[int] = None,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
@@ -234,7 +235,7 @@ class ModelRunner:
min_per_gpu_memory = self.init_torch_distributed()
# CPU offload
set_offloader(create_offloader_from_server_args(server_args))
set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
# Update deep gemm configure
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
......
@@ -1996,6 +1996,23 @@ class DeepseekV2Model(nn.Module):
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
offloader_kwargs=dict(
submodule_accessor=lambda layer: (
layer.mlp.experts
if isinstance(layer.mlp, DeepseekV2MoE)
else layer.mlp
),
whitelist_param_names_creator=lambda module: (
[
"w13_weight",
"w2_weight",
"w13_blockscale_swizzled",
"w2_blockscale_swizzled",
]
if isinstance(module, FusedMoE)
else []
),
),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
import logging
import os
from abc import ABC
from typing import Callable, Generator, List, Optional
import torch
from torch.func import functional_call
from sglang.srt.distributed.naive_distributed import (
NaiveDistributed,
get_naive_distributed,
set_naive_distributed,
)
from sglang.srt.host_shared_memory import (
HostSharedMemoryManager,
get_host_shared_memory_manager,
set_host_shared_memory_manager,
)
from sglang.srt.layers.parameter import ModelWeightParameter
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_pin_memory_available
from sglang.srt.utils import MultiprocessingSerializer, is_pin_memory_available
logger = logging.getLogger(__name__)
@@ -45,11 +57,23 @@ def set_offloader(instance: BaseOffloader):
_instance = instance
def create_offloader_from_server_args(server_args: ServerArgs):
def create_offloader_from_server_args(server_args: ServerArgs, dp_rank: int):
if server_args.cpu_offload_gb > 0:
return OffloaderV1(
cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
)
if server_args.offload_group_size > 0:
assert (
server_args.cpu_offload_gb == 0
), "V2 offload does not support cpu_offload_gb yet"
return OffloaderV2(
group_size=server_args.offload_group_size,
num_in_group=server_args.offload_num_in_group,
prefetch_step=server_args.offload_prefetch_step,
mode=server_args.offload_mode,
dp_rank=dp_rank,
dp_size=server_args.dp_size,
)
return NoopOffloader()
@@ -120,3 +144,290 @@ class OffloaderV1(BaseOffloader):
module.forward = forward
return module
class OffloaderV2(BaseOffloader):
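    """For every `group_size` consecutive layers, offloads the selected submodule of
    the last `num_in_group` of them. Their weights are brought back to the device on
    a side CUDA stream, `prefetch_step` modules ahead of use, so the copies overlap
    with compute."""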
def __init__(
self,
group_size: int,
num_in_group: int,
prefetch_step: int,
mode: str,
dp_rank: int,
dp_size: int,
):
self.group_size = group_size
self.num_in_group = num_in_group
self.prefetch_step = prefetch_step
self.mode = mode
run_id = os.environ["SGLANG_RUN_ID"]
        # Temporarily initialized inside the offloader; this can be moved out if other modules also need it
if self.mode in {"sharded_gpu", "shm_cpu"}:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
            assert (
                get_tensor_model_parallel_world_size() == 1
            ), "tp_size != 1 is not yet supported"
set_naive_distributed(
NaiveDistributed(
rank=dp_rank,
world_size=dp_size,
rendezvous=f"/tmp/{run_id}",
)
)
if self.mode in {"shm_cpu"}:
set_host_shared_memory_manager(
HostSharedMemoryManager(
base_name=run_id,
)
)
self.offloaders = []
def wrap_modules(
self,
all_modules_generator: Generator[torch.nn.Module, None, None],
submodule_accessor: Optional[_SubmoduleAccessor] = None,
whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
):
assert len(self.offloaders) == 0, "should only call wrap_modules once"
alt_stream = torch.cuda.Stream()
all_modules = []
offload_submodules = []
for module_index, module in enumerate(all_modules_generator):
all_modules.append(module)
if module_index % self.group_size >= self.group_size - self.num_in_group:
submodule = submodule_accessor(module)
whitelist_param_names = whitelist_param_names_creator(submodule)
logger.info(
f"[offloader] offload {module_index=} submodule={type(submodule)} params={whitelist_param_names} memory_allocated={torch.cuda.memory_allocated()}"
)
offload_submodules.append(submodule)
self.offloaders.append(
_ModuleOffloader(
mode=self.mode,
module=submodule,
alt_stream=alt_stream,
whitelist_param_names=whitelist_param_names,
)
)
for index, module in enumerate(offload_submodules):
_hook_module_forward_for_offloader(
index=index,
module=module,
offloaders=self.offloaders,
prefetch_step=self.prefetch_step,
)
return all_modules
def post_init(self):
for offloader in self.offloaders:
offloader.post_init()
for i in range(self.prefetch_step):
self.offloaders[i].start_onload()
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
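    # When this module's forward finishes, start onloading the module that is
    # `prefetch_step` positions ahead and drop this module's device copy, so the
    # host-to-device copy overlaps with the compute of the modules in between.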
def _on_forward_end():
offloaders[(index + prefetch_step) % len(offloaders)].start_onload()
offloaders[index].offload()
_hook_module_forward_raw(
module,
on_forward_end=_on_forward_end,
get_parameter_and_buffer_dicts=lambda: offloaders[
index
].wait_and_get_device_tensors(),
)
def _hook_module_forward_raw(module, on_forward_end, get_parameter_and_buffer_dicts):
original_forward = module.forward
def forward(*args, **kwargs):
module.forward = original_forward
output = functional_call(
module, get_parameter_and_buffer_dicts(), args=args, kwargs=kwargs
)
on_forward_end()
module.forward = forward
return output
module.forward = forward
class _ModuleOffloader(ABC):
def __init__(
self,
mode: str,
module: torch.nn.Module,
alt_stream: torch.cuda.Stream,
whitelist_param_names: List[str],
):
self.mode = mode
self.module = module
self.device = next(module.parameters()).device
self.alt_stream = alt_stream
        assert self.device != torch.device(
            "cpu"
        ), "device=cpu is not handled yet (such tensors should be skipped)"
self._device_tensors = None
self._load_event = None
param_dict = dict(self.module.named_parameters())
assert all(
name in param_dict for name in whitelist_param_names
), f"{whitelist_param_names=} {list(param_dict.keys())=}"
self._param_offloaders = {
name: _BaseParamOffloader.create(mode, module=module, param_name=name)
for name in whitelist_param_names
}
def post_init(self):
for name, param_offloader in self._param_offloaders.items():
param_offloader.post_init()
def start_onload(self):
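        # Issue the host-to-device copies on the side stream and record an event;
        # wait_and_get_device_tensors() later makes the consuming stream wait on
        # that event right before the weights are needed.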
self.alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.alt_stream):
self._device_tensors = self._create_device_tensors()
self._load_event = torch.cuda.Event()
self._load_event.record()
def offload(self):
self._device_tensors = None
self._load_event = None
def wait_and_get_device_tensors(self):
assert self._device_tensors is not None
self._load_event.wait()
return self._device_tensors
def _create_device_tensors(self):
return {k: v.create_device_tensor() for k, v in self._param_offloaders.items()}
class _BaseParamOffloader(ABC):
@staticmethod
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
return {
"cpu": _CpuParamOffloader,
"shm_cpu": _ShmCpuParamOffloader,
"sharded_gpu": _ShardedGpuParamOffloader,
}[mode](**kwargs)
def __init__(self, module, param_name):
self._module = module
self._param_name = param_name
@property
def _param(self):
return getattr(self._module, self._param_name)
def post_init(self):
pass
def create_device_tensor(self):
raise NotImplementedError
class _CpuParamOffloader(_BaseParamOffloader):
def __init__(self, module, param_name):
super().__init__(module, param_name)
_move_param_to_cpu(self._param, pin_memory=True)
def create_device_tensor(self):
return self._param.to("cuda", non_blocking=True)
class _ShmCpuParamOffloader(_BaseParamOffloader):
def __init__(self, module, param_name):
super().__init__(module, param_name)
self._rank = get_naive_distributed().get_rank()
self._world_size = get_naive_distributed().get_world_size()
from sglang.srt.distributed import get_tensor_model_parallel_world_size
        assert get_tensor_model_parallel_world_size() == 1, "tp_size != 1 is not yet supported"
        assert (
            self._param.data.is_contiguous()
        ), f"non-contiguous tensors are not yet supported: {self._param.shape=} {self._param.stride()=}"
self.shm_cpu_data = get_host_shared_memory_manager().malloc(
shape=self._param.shape, dtype=self._param.dtype
)
if self._rank == 0:
self.shm_cpu_data.copy_(self._param.data.to("cpu"))
self._param.data = self.shm_cpu_data
else:
_move_param_to_meta(self._module, self._param_name)
get_naive_distributed().barrier()
def post_init(self):
if self._rank == 0:
assert (
self.shm_cpu_data.data_ptr() == self._param.data.data_ptr()
), f"{self.shm_cpu_data.data_ptr()=} {self._param.data.data_ptr()=} {self.shm_cpu_data=} {self._param.data=}"
_move_param_to_meta(self._module, self._param_name)
def create_device_tensor(self):
return self.shm_cpu_data.to("cuda", non_blocking=True)
def _move_param_to_cpu(param, pin_memory: bool):
cpu_data = _empty_strided_like(
param.data,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(param.data)
param.data = cpu_data
def _move_param_to_meta(module, param_name):
old_param = getattr(module, param_name)
old_param_type = type(old_param)
new_data = old_param.data.to("meta")
if old_param_type == ModelWeightParameter:
# manually checked how `w13_weight` and `w2_weight` are constructed
new_param = ModelWeightParameter(
data=new_data,
**{
k: getattr(old_param, k)
for k in ["input_dim", "output_dim", "weight_loader"]
},
)
elif old_param_type == torch.nn.Parameter:
new_param = torch.nn.Parameter(
data=new_data,
requires_grad=False,
)
else:
raise ValueError(f"Unknown {old_param_type=} {old_param=}")
setattr(module, param_name, new_param)
def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
return torch.empty_strided(
size=x.size(),
stride=x.stride(),
dtype=x.dtype,
layout=x.layout,
device=device,
pin_memory=pin_memory,
)
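To make the grouping arithmetic above concrete, here is a small sketch (not part of this commit) that mirrors the modulo test in `wrap_modules` and lists which layer indices end up offloaded for a given configuration.

```python
# Hypothetical helper mirroring the condition in OffloaderV2.wrap_modules.
def offloaded_layer_indices(num_layers: int, group_size: int, num_in_group: int):
    return [
        i for i in range(num_layers) if i % group_size >= group_size - num_in_group
    ]


# With 8 layers, group_size=4 and num_in_group=1, the last layer of every group
# of 4 is offloaded, i.e. layers 3 and 7.
assert offloaded_layer_indices(8, 4, 1) == [3, 7]
# With num_in_group=2, the last two layers of each group are offloaded.
assert offloaded_layer_indices(8, 4, 2) == [2, 3, 6, 7]
```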
@@ -85,7 +85,6 @@ class ServerArgs:
max_prefill_tokens: int = 16384
schedule_policy: str = "fcfs"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
page_size: Optional[int] = None
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
@@ -226,6 +225,13 @@ class ServerArgs:
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# Offloading
cpu_offload_gb: int = 0
offload_group_size: int = -1
offload_num_in_group: int = 1
offload_prefetch_step: int = 1
offload_mode: str = "cpu"
# Optimization/debug options
disable_radix_cache: bool = False
cuda_graph_max_bs: Optional[int] = None
@@ -976,12 +982,6 @@ class ServerArgs:
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
parser.add_argument(
"--cpu-offload-gb",
type=int,
default=ServerArgs.cpu_offload_gb,
help="How many GBs of RAM to reserve for CPU offloading.",
)
parser.add_argument(
"--page-size",
type=int,
@@ -1683,6 +1683,38 @@ class ServerArgs:
help="The type of heavy channels in double sparsity attention",
)
# Offloading
parser.add_argument(
"--cpu-offload-gb",
type=int,
default=ServerArgs.cpu_offload_gb,
help="How many GBs of RAM to reserve for CPU offloading.",
)
parser.add_argument(
"--offload-group-size",
type=int,
default=ServerArgs.offload_group_size,
help="Number of layers per group in offloading.",
)
parser.add_argument(
"--offload-num-in-group",
type=int,
default=ServerArgs.offload_num_in_group,
help="Number of layers to be offloaded within a group.",
)
parser.add_argument(
"--offload-prefetch-step",
type=int,
default=ServerArgs.offload_prefetch_step,
help="Steps to prefetch in offloading.",
)
parser.add_argument(
"--offload-mode",
type=str,
default=ServerArgs.offload_mode,
help="Mode of offloading.",
)
# Optimization/debug options
parser.add_argument(
"--disable-radix-cache",
......
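For context, a hedged example of turning on the new offloading options through sglang's offline Engine API, assuming the Engine forwards keyword arguments to ServerArgs as it does for the existing flags; the model path and values below are placeholders, not part of this commit.

```python
# Hypothetical launch sketch; assumes Engine kwargs map onto the ServerArgs fields.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model path
    offload_group_size=4,     # consider every group of 4 consecutive layers...
    offload_num_in_group=1,   # ...and offload the last 1 layer of each group
    offload_prefetch_step=1,  # onload 1 offloaded module ahead of time
    offload_mode="cpu",       # "cpu", "shm_cpu", or "sharded_gpu"
)
print(llm.generate("The capital of France is"))
```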
@@ -2954,3 +2954,13 @@ class ConcurrentCounter:
@lru_cache(maxsize=1)
def is_triton_kernels_available() -> bool:
return importlib.util.find_spec("triton_kernels") is not None
def check_cuda_result(raw_output):
import cuda.bindings.runtime as cuda_rt
err, *results = raw_output
if err != cuda_rt.cudaError_t.cudaSuccess:
raise Exception(f"CUDA error: {err}")
return results
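A short usage sketch for the new helper (not part of this commit), assuming cuda-python is installed and a GPU is visible:

```python
# cuda-python runtime calls return (error_code, *values); check_cuda_result raises
# on a non-success code and hands back the remaining values.
import cuda.bindings.runtime as cuda_rt

from sglang.srt.utils import check_cuda_result

free_bytes, total_bytes = check_cuda_result(cuda_rt.cudaMemGetInfo())
print(f"free={free_bytes} total={total_bytes}")
```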