"tests/vscode:/vscode.git/clone" did not exist on "e162310627e20330e6d779a30c47f29f7b3452cc"
Unverified commit d0510f08, authored by Sai Enduri, committed by GitHub

Revert "Fix different device type adjustment in PP" (#8141)

parent 9d33fcfb
@@ -699,14 +699,14 @@ class GroupCoordinator:
         )
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to(
-            device=self.device
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
         )
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=self.device,
+            device=torch.cuda.current_device(),
         )
         # Send object size
@@ -731,7 +731,9 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
-        size_tensor = torch.empty(1, dtype=torch.long, device=self.device)
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
         # Receive object size
         rank_size = torch.distributed.recv(
@@ -742,7 +744,7 @@ class GroupCoordinator:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=self.device,
+            device=torch.cuda.current_device(),
         )
         rank_object = torch.distributed.recv(
...
@@ -975,7 +975,6 @@ class Scheduler(
                     self.world_group.device_group,
                     self.pp_rank * self.tp_size + dp_offset,
                     (self.pp_rank + 1) * self.tp_size + dp_offset,
-                    device=self.device,
                 )
             # send out proxy tensors to the next stage
@@ -1024,7 +1023,6 @@ class Scheduler(
                     self.world_group.device_group,
                     (self.pp_rank - 1) * self.tp_size + dp_offset,
                     self.pp_rank * self.tp_size + dp_offset,
-                    device=self.device,
                 )
             else:
                 recv_reqs = None
@@ -1055,7 +1053,6 @@ class Scheduler(
                 self.attn_tp_group.rank,
                 self.attn_tp_cpu_group,
                 src=self.attn_tp_group.ranks[0],
-                device=self.device,
             )
             if self.tp_size != 1:
                 control_reqs = broadcast_pyobj(
@@ -1063,7 +1060,6 @@ class Scheduler(
                     self.tp_group.rank,
                     self.tp_cpu_group,
                     src=self.tp_group.ranks[0],
-                    device=self.device,
                 )
             recv_reqs = work_reqs + control_reqs
         elif self.tp_size != 1:
@@ -1072,7 +1068,6 @@ class Scheduler(
                 self.tp_group.rank,
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
-                device=self.device,
             )
         return recv_reqs
...
@@ -144,7 +144,6 @@ class TpModelWorker:
             self.tp_size * self.pp_rank + tp_rank,
             self.world_group.cpu_group,
             src=self.world_group.ranks[0],
-            device=self.device,
         )[0]
         set_random_seed(self.random_seed)
...
@@ -1100,15 +1100,15 @@ def broadcast_pyobj(
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
-    device: Optional[str] = None,
+    force_cpu_device: bool = True,
 ):
     """Broadcast inputs from src rank to all other ranks with torch.dist backend.
     The `rank` here refer to the source rank on global process group (regardless
     of dist_group argument).
     """
-    if device is None:
-        device = get_device()
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
+    )
     if rank == src:
         if len(data) == 0:
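For reference, the device selection restored in `broadcast_pyobj` above can be read in isolation as the small helper below. This is a minimal sketch for illustration only; the helper name is hypothetical, and the real function also handles serialization and the broadcast itself.

```python
import torch


def pick_broadcast_device(force_cpu_device: bool = True) -> torch.device:
    # After this revert, callers no longer pass an explicit `device`; the device
    # is derived from CUDA availability and the force_cpu_device flag instead.
    return torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )
```

With the default `force_cpu_device=True`, the broadcast stays on CPU tensors even when CUDA is available.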
@@ -1148,38 +1148,44 @@ def point_to_point_pyobj(
     group: Optional[torch.distributed.ProcessGroup] = None,
     src: int = 0,
     dst: int = 1,
-    device: Optional[str] = None,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
-    if device is None:
-        device = get_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).to(
-                device=device
-            )  # Move to Device
-            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
     elif rank == dst:
-        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
         if size == 0:
             return []
-        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_data, src=src, group=group)
         serialized_data = bytes(
...
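The hunks above restore the CUDA-tensor path in `point_to_point_pyobj`: the object is pickled, its byte length is sent first so the receiver can allocate a buffer, and the payload follows. Below is a minimal two-rank sketch of that pattern; the helper names and the process-group setup are illustrative, not part of this commit.

```python
import pickle

import numpy as np
import torch
import torch.distributed as dist


def send_pyobj_sketch(obj, dst: int, group=None):
    # Pickle the object and move the bytes to the current CUDA device,
    # mirroring the reverted code path above.
    payload = pickle.dumps(obj)
    device = torch.cuda.current_device()
    data = torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8)).cuda(device=device)
    size = torch.tensor([len(payload)], dtype=torch.long, device=device)
    dist.send(size, dst=dst, group=group)  # size first, so the receiver can allocate
    dist.send(data, dst=dst, group=group)  # then the serialized bytes


def recv_pyobj_sketch(src: int, group=None):
    # Receive the size, allocate a byte buffer of that length, then unpickle.
    device = torch.cuda.current_device()
    size = torch.tensor([0], dtype=torch.long, device=device)
    dist.recv(size, src=src, group=group)
    if size.item() == 0:
        return []
    data = torch.empty(size.item(), dtype=torch.uint8, device=device)
    dist.recv(data, src=src, group=group)
    return pickle.loads(bytes(data.cpu().numpy()))
```

Both sides assume a CUDA device is current, which is the behavior this revert restores.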