Unverified Commit 00991723 authored by TianyuZhang1214, committed by GitHub

feat: use D2D instead of H2H in pp (#7673)


Co-authored-by: alpha-baby <fujianhao1997@qq.com>
parent 264dc6e7
@@ -699,18 +699,25 @@ class GroupCoordinator:
         )
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
-        size_tensor = torch.tensor(
-            [object_tensor.numel()], dtype=torch.long, device="cpu"
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
+        )
+        size_tensor = torch.tensor(
+            [object_tensor.numel()],
+            dtype=torch.long,
+            device=torch.cuda.current_device(),
         )
         # Send object size
-        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            size_tensor, dst=self.ranks[dst], group=self.device_group
+        )
         # Send object
-        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+        torch.distributed.send(
+            object_tensor, dst=self.ranks[dst], group=self.device_group
+        )
         return None
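For reference, a minimal standalone sketch of the sender half of this size-prefixed protocol (the helper name `send_pyobj` is ours, not from the patch). It assumes `dist.init_process_group(backend="nccl")` has run and a CUDA device is set for this rank, so `torch.distributed.send` on CUDA tensors moves data device-to-device:

```python
# Sketch only: sender half of a size-prefixed pickle transfer over an NCCL
# (device) group. Assumes an initialized NCCL process group and that
# torch.cuda.set_device() was called for this rank.
import pickle

import torch
import torch.distributed as dist


def send_pyobj(obj, dst: int) -> None:
    # Serialize on the host, then stage the bytes on the GPU for a D2D send.
    # bytearray(...) makes the buffer writable, avoiding torch.frombuffer's
    # read-only-buffer warning.
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8).cuda()
    size = torch.tensor([payload.numel()], dtype=torch.long, device="cuda")
    # The size must travel first: the receiver cannot allocate its buffer
    # until it knows how many bytes follow.
    dist.send(size, dst=dst)
    dist.send(payload, dst=dst)
```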
@@ -724,29 +731,31 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
-        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.cpu_group
+            size_tensor, src=self.ranks[src], group=self.device_group
         )
         # Tensor to receive serialized objects into.
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device="cpu",
+            device=torch.cuda.current_device(),
         )
         rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.cpu_group
+            object_tensor, src=self.ranks[src], group=self.device_group
         )
         assert (
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."
-        obj = pickle.loads(object_tensor.numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
         return obj
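And a matching receiver half, again a sketch under the same assumptions rather than the project's API. The size arrives first so the payload buffer can be allocated, and the payload must be copied back to the host before `pickle.loads`, since CUDA tensors cannot be viewed as NumPy arrays directly; that is the extra D2H hop this patch knowingly accepts:

```python
# Sketch only: receiver half matching send_pyobj above, under the same
# NCCL-group assumptions.
import pickle

import torch
import torch.distributed as dist


def recv_pyobj(src: int):
    size = torch.empty(1, dtype=torch.long, device="cuda")
    dist.recv(size, src=src)
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device="cuda")
    dist.recv(payload, src=src)
    # CUDA tensors cannot be viewed as NumPy arrays; copy to the host first.
    # This is the D2H copy the patch's comment acknowledges.
    return pickle.loads(payload.cpu().numpy().tobytes())
```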
@@ -857,14 +866,16 @@ class GroupCoordinator:
             dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict, dict
         ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `send_object_list` has serialization & deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
+        # Note: While switching to Device-to-Device (D2D) would introduce an extra
+        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+        # show better overall transmission performance with D2D due to:
+        # 1. Superior D2D transfer bandwidth
+        # 2. Ability to overlap send and recv operations
+        # Thus the net performance gain justifies this approach.
         self.send_object(metadata_list, dst=dst)
         for tensor in tensor_list:
             if tensor.numel() == 0:
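The second bullet in the new comment, overlapping send and recv, is something NCCL supports natively. One hedged illustration of such overlap (not code from this patch) uses `torch.distributed.batch_isend_irecv` to post both transfers before waiting on either:

```python
# Sketch only: overlapping a send and a recv on an NCCL group, the kind of
# overlap the comment's point 2 refers to. Assumes an initialized NCCL group.
import torch
import torch.distributed as dist


def exchange(
    send_buf: torch.Tensor, recv_buf: torch.Tensor, send_to: int, recv_from: int
) -> None:
    ops = [
        dist.P2POp(dist.isend, send_buf, send_to),
        dist.P2POp(dist.irecv, recv_buf, recv_from),
    ]
    # Both transfers are posted before either is waited on, so they can
    # proceed concurrently on the device.
    for work in dist.batch_isend_irecv(ops):
        work.wait()
```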
@@ -928,7 +928,7 @@ class Scheduler(
             point_to_point_pyobj(
                 recv_reqs,
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 self.pp_rank * self.tp_size + dp_offset,
                 (self.pp_rank + 1) * self.tp_size + dp_offset,
             )
@@ -975,7 +975,7 @@ class Scheduler(
             recv_reqs = point_to_point_pyobj(
                 [],
                 self.pp_rank * self.tp_size + dp_offset,
-                self.world_group.cpu_group,
+                self.world_group.device_group,
                 (self.pp_rank - 1) * self.tp_size + dp_offset,
                 self.pp_rank * self.tp_size + dp_offset,
             )
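The rank arguments at these call sites encode pipeline-stage placement: with tensor parallelism, stage `i` owns the contiguous global ranks starting at `i * tp_size + dp_offset`. A small worked example with assumed values:

```python
# Worked example with assumed values (not from the patch):
# pipeline stage i starts at global rank i * tp_size + dp_offset.
tp_size, dp_offset, pp_rank = 2, 0, 1
recv_src = (pp_rank - 1) * tp_size + dp_offset  # 0: first rank of stage 0
self_rank = pp_rank * tp_size + dp_offset       # 2: first rank of stage 1
send_dst = (pp_rank + 1) * tp_size + dp_offset  # 4: first rank of stage 2
assert (recv_src, self_rank, send_dst) == (0, 2, 4)
```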
@@ -1000,36 +1000,48 @@ def point_to_point_pyobj(
     src: int = 0,
     dst: int = 1,
 ):
-    """Send data from src to dst in group."""
+    """Send data from src to dst in group using DeviceToDevice communication."""
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            )
-            tensor_size = torch.tensor([size], dtype=torch.long)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data

     elif rank == dst:
-        tensor_size = torch.tensor([0], dtype=torch.long)
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
         if size == 0:
             return []
-        tensor_data = torch.empty(size, dtype=torch.uint8)
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_data, src=src, group=group)
-        serialized_data = bytes(tensor_data.cpu().numpy())
+        serialized_data = bytes(
+            tensor_data.cpu().numpy()
+        )  # Move back to host for deserialization
        data = pickle.loads(serialized_data)
         return data
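To see the function end to end, here is a hedged two-rank driver sketch. The import path is hypothetical (adjust it to wherever `point_to_point_pyobj` lives in your tree); the positional signature `(data, rank, group, src, dst)` matches the call sites above. Run with `torchrun --nproc_per_node=2 demo.py`:

```python
# Sketch only: two-rank driver for point_to_point_pyobj over an NCCL group.
import torch
import torch.distributed as dist

from sglang.srt.managers.scheduler import point_to_point_pyobj  # hypothetical path


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    # Rank 0 sends a small pyobj list; rank 1 receives it.
    data = [{"req_id": 1, "prompt_len": 128}] if rank == 0 else []
    out = point_to_point_pyobj(data, rank, dist.group.WORLD, src=0, dst=1)
    if rank == 1:
        print(f"rank 1 received: {out}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```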