utils.py 609 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import tempfile
from contextlib import contextmanager, nullcontext
from typing import Iterator

import torch.distributed as dist


@contextmanager
def shared_tempdir() -> Iterator[str]:
    """
    Yield the path of one temporary directory on every participating process.

    With an initialised ``torch.distributed`` process group, rank 0 creates
    the directory and broadcasts its path; every rank then yields that same
    path. All ranks synchronise on a barrier when leaving the block, so
    rank 0 removes the directory only after every rank has finished with it.

    Without an initialised process group (e.g. a plain single-process run),
    this falls back to an ordinary ``tempfile.TemporaryDirectory``.

    Only the path string is broadcast: the directory itself is created by
    rank 0 alone, so other ranks can use it only if they see rank 0's
    filesystem.

    Yields:
        The path of the temporary directory.
    """
    # Non-distributed fallback: no peers to broadcast to or to wait for.
    if not (dist.is_available() and dist.is_initialized()):
        with tempfile.TemporaryDirectory() as tempdir:
            yield tempdir
        return

    # Only rank 0 owns (creates and later deletes) the directory; the other
    # ranks enter a no-op context whose value is None until the broadcast.
    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
    with ctx_fn() as tempdir:
        try:
            obj = [tempdir]
            dist.broadcast_object_list(obj, src=0)
            tempdir = obj[0]    # use the same directory on all ranks
            yield tempdir
        finally:
            # Runs inside the ``with`` so that rank 0 waits for every rank
            # before TemporaryDirectory cleans up, even if the body raised.
            dist.barrier()