Unverified Commit d1c3d7d1 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[misc][distributed] fix benign error in `is_in_the_same_node` (#5512)

parent 77490c6f
......@@ -23,8 +23,9 @@ import contextlib
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import resource_tracker, shared_memory
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
from torch.distributed import Backend, ProcessGroup
......@@ -744,6 +745,11 @@ def is_in_the_same_node(pg: ProcessGroup):
src=ranks[0],
group=pg)
name = recv[0]
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
shm = shared_memory.SharedMemory(name=name)
if shm.buf[:len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1
......@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == 0:
if shm:
if rank == 0 and shm:
shm.unlink()
else:
if shm:
# fix to https://stackoverflow.com/q/62748654/9191338
resource_tracker.unregister(
shm._name, "shared_memory") # type: ignore[attr-defined]
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
return is_in_the_same_node.sum().item() == world_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment