context.py 1.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class barrier_context:
    """Run a section of code on one rank while every other rank in the group waits.

    Typical use is a download or other file-producing step: only one process
    should write the files (to avoid corrupting them), and the remaining
    processes should wait until it has finished before touching them.

    Synchronisation scheme: each rank in the process group calls
    ``dist.barrier`` exactly once. Non-executor ranks call it when entering
    the ``with`` block, so they stall there. The executor rank skips that
    barrier, runs the body, and calls the barrier when leaving the block.
    That releases the others, which then run the body themselves.

    Args:
        executor_rank (int): local rank (within ``parallel_mode``) that runs
            the body first without waiting. Defaults to 0.
        parallel_mode (ParallelMode): parallel mode whose process group is
            used for the barrier. Defaults to ``ParallelMode.GLOBAL``.

    Usage:
        with barrier_context():
            dataset = CIFAR10(root='./data', download=True)
    """

    # NOTE: the class name is lowercase by convention.

    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
        my_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
        is_executor = my_rank == executor_rank
        # True for every rank except the executor; such ranks wait on entry.
        self.should_block = not is_executor
        # Process group passed to dist.barrier in __enter__/__exit__.
        self.group = gpc.get_group(parallel_mode=parallel_mode)

    def __enter__(self):
        # The executor goes straight into the body.
        if not self.should_block:
            return
        # Everyone else parks here until the executor reaches __exit__.
        dist.barrier(group=self.group)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Waiting ranks already consumed their barrier in __enter__.
        if self.should_block:
            return
        # The executor joins the barrier only after its body has finished.
        # This runs even if the body raised, so the waiting ranks are always
        # released. Returning None lets any exception propagate.
        dist.barrier(group=self.group)