Unverified Commit 71a7f1d8 authored by fzyzcjy, committed by GitHub

Offload tensors by sharding on GPU (#9536)

parent 433266c1
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        return {
            "meta": _MetaParamOffloader,
            "cpu": _CpuParamOffloader,
            "shm_cpu": _ShmCpuParamOffloader,
            "sharded_gpu": _ShardedGpuParamOffloader,
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
        raise NotImplementedError
class _MetaParamOffloader(_BaseParamOffloader):
    """Usually used for debugging."""

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        _move_param_to_meta(module, param_name)

    def create_device_tensor(self):
        return torch.empty_like(self._param.data, device="cuda")
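
A minimal, self-contained sketch of the meta-device behavior this class relies on (plain PyTorch, independent of sglang):

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 4)
    lin.weight.data = lin.weight.data.to("meta")  # drop storage, keep shape/dtype
    assert lin.weight.is_meta                     # no real memory is held
    restored = torch.empty_like(lin.weight.data, device="cuda")  # materialize later
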
class _CpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
        device=device,
        pin_memory=pin_memory,
    )
# ----------------------------------------- ShardedGpu ------------------------------------------------------
# TODO unify with ShmCpu mode
class _ShardedGpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        self._rank = get_naive_distributed().get_rank()
        self._world_size = get_naive_distributed().get_world_size()

        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        if self._rank == 0:
            _move_param_to_cpu(self._param, pin_memory=True)
        else:
            _move_param_to_meta(self._module, self._param_name)

        self.sharded_param_handles = None
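
Design note, inferred from the code above: only rank 0 keeps a host copy of the full weight, pinned so the later upload to CUDA can use an asynchronous DMA transfer, while every other rank drops its copy to the meta device. A minimal sketch of the pinned-host-to-GPU pattern:

    import torch

    host = torch.empty(1024, pin_memory=True)  # page-locked host allocation
    dev = host.to("cuda", non_blocking=True)   # pinning allows an async H2D copy
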
    def post_init(self):
        # re-check: the param may have been replaced since __init__
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        scatter_src = self._param.data
        logger.info(
            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
        )
        if self._rank == 0:
            scatter_src = scatter_src.to("cuda")
        scatter_list = _even_chunk(scatter_src, self._world_size)

        sharded_param = torch.empty(
            scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
        )
        self.sharded_param_handles = _create_shared_buffer_tensors(
            local_tensor=sharded_param
        )

        get_naive_distributed().scatter(
            sharded_param, scatter_list if self._rank == 0 else None
        )

        _move_param_to_meta(self._module, self._param_name)
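
A hedged sketch of the scatter step using stock torch.distributed, assuming the naive wrapper's scatter mirrors dist.scatter with rank 0 as the source (scatter_shard is an illustrative name, not sglang's API):

    import torch
    import torch.distributed as dist

    def scatter_shard(full_weight_cpu):
        # rank 0 uploads the full weight, splits it into equal chunks,
        # and every rank receives exactly one chunk.
        world = dist.get_world_size()
        shard_shape = (full_weight_cpu.shape[0] // world, *full_weight_cpu.shape[1:])
        shard = torch.empty(shard_shape, dtype=full_weight_cpu.dtype, device="cuda")
        scatter_list = None
        if dist.get_rank() == 0:
            scatter_list = list(full_weight_cpu.to("cuda").chunk(world))
        dist.scatter(shard, scatter_list, src=0)
        return shard
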
    def create_device_tensor(self):
        output = _empty_strided_like(self._param, device="cuda")
        output_chunks = output.chunk(self._world_size)

        for index in range(self._world_size):
            src_rank = (self._rank + index) % self._world_size
            src_buf = self.sharded_param_handles[src_rank]
            output_chunks[src_rank].copy_(src_buf)

        return output
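
The loop above reassembles the full weight by copying each rank's shard out of the shared buffers, walking the ranks round-robin starting from self._rank, presumably so peers do not all read the same source buffer first. A single-process illustration of the reassembly arithmetic:

    import torch

    world_size = 4
    full = torch.arange(8.0)
    shards = list(full.chunk(world_size))    # what the earlier scatter distributed
    output = torch.empty_like(full)
    for src_rank in range(world_size):
        output.chunk(world_size)[src_rank].copy_(shards[src_rank])  # views write through
    assert torch.equal(output, full)
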
def _even_chunk(x: torch.Tensor, chunks: int):
    assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
    return list(x.chunk(chunks))
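
The divisibility assert matters because torch.chunk silently returns fewer, uneven pieces when the leading dimension does not divide evenly, which would break the uniform-shard assumption of the scatter. For example:

    import torch

    x = torch.zeros(6, 2)
    assert len(x.chunk(3)) == 3  # even split: three [2, 2] shards
    assert len(x.chunk(4)) == 3  # only 3 chunks come back; _even_chunk rejects this
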
def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
    self_rank = get_naive_distributed().get_rank()
    world_size = get_naive_distributed().get_world_size()

    object_list = get_naive_distributed().all_gather_object(
        dict(
            dup_serialized_local_tensor=[
                (
                    None
                    if interesting_rank == self_rank
                    else MultiprocessingSerializer.serialize(local_tensor)
                )
                for interesting_rank in range(world_size)
            ]
        )
    )

    output_tensors = []
    for output_rank in range(world_size):
        remote_serialized_tensor = object_list[output_rank][
            "dup_serialized_local_tensor"
        ][self_rank]
        if output_rank == self_rank:
            assert remote_serialized_tensor is None
            output_tensors.append(local_tensor)
        else:
            output_tensors.append(
                MultiprocessingSerializer.deserialize(remote_serialized_tensor)
            )
    return output_tensors
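
For intuition, a hedged sketch of the mechanism such a serializer can build on (assumption: MultiprocessingSerializer wraps torch's ForkingPickler-style tensor reduction, so deserializing yields a view of the sender's GPU memory rather than a copy; serialize/deserialize below are illustrative stand-ins, not sglang's API):

    import pickle
    from torch.multiprocessing.reductions import reduce_tensor

    def serialize(t):
        rebuild_fn, args = reduce_tensor(t)  # args carry a CUDA IPC handle, not the data
        return pickle.dumps((rebuild_fn, args))

    def deserialize(buf):
        rebuild_fn, args = pickle.loads(buf)
        return rebuild_fn(*args)  # maps the sender's allocation in this process
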