Unverified Commit 357fb2db authored by Junrong Lin's avatar Junrong Lin Committed by GitHub
Browse files

fix: fix broadcast_pyobj breaking VerlEngine (#5997)

parent 95c231e5
...@@ -37,6 +37,7 @@ class VerlEngine: ...@@ -37,6 +37,7 @@ class VerlEngine:
monkey_patch_torch_reductions() monkey_patch_torch_reductions()
self._device_mesh_cpu = device_mesh_cpu self._device_mesh_cpu = device_mesh_cpu
self._tp_rank = device_mesh_cpu.get_local_rank() self._tp_rank = device_mesh_cpu.get_local_rank()
self._rank = device_mesh_cpu.get_rank()
self._tp_size = device_mesh_cpu.size() self._tp_size = device_mesh_cpu.size()
tp_size_per_node = self._tp_size // nnodes tp_size_per_node = self._tp_size // nnodes
node_rank = self._tp_rank // tp_size_per_node node_rank = self._tp_rank // tp_size_per_node
...@@ -114,7 +115,7 @@ class VerlEngine: ...@@ -114,7 +115,7 @@ class VerlEngine:
# Most naive implementation, can extract tensor and send via gloo if too slow # Most naive implementation, can extract tensor and send via gloo if too slow
[output] = broadcast_pyobj( [output] = broadcast_pyobj(
data=[output], data=[output],
rank=self._tp_rank, rank=self._rank,
dist_group=self._device_mesh_cpu.get_group(), dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(), src=self._device_mesh_cpu.mesh[0].item(),
force_cpu_device=False, force_cpu_device=False,
...@@ -157,7 +158,7 @@ class VerlEngine: ...@@ -157,7 +158,7 @@ class VerlEngine:
) )
if self._tp_rank == 0: if self._tp_rank == 0:
self._engine.tokenizer_manager.flush_cache() self._engine.flush_cache()
def release_memory_occupation(self): def release_memory_occupation(self):
if self._tp_rank == 0: if self._tp_rank == 0:
......
...@@ -897,7 +897,10 @@ def broadcast_pyobj( ...@@ -897,7 +897,10 @@ def broadcast_pyobj(
src: int = 0, src: int = 0,
force_cpu_device: bool = True, force_cpu_device: bool = True,
): ):
"""Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" """Broadcast inputs from src rank to all other ranks with torch.dist backend.
The `rank` here refer to the source rank on global process group (regardless
of dist_group argument).
"""
device = torch.device( device = torch.device(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment