utils.py
"""
Training utilities for Coati.
"""
from typing import Any

import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader


class CycledDataLoader:
    """
    A data loader that cycles through the data when it reaches the end.

    Args:
        dataloader (DataLoader): The original data loader.

    Attributes:
        dataloader (DataLoader): The original data loader.
        count (int): The number of batches fetched in the current pass over the data; reset to 0 when the loader wraps around.
        dataloader_iter (iterable): The iterator for the data loader.

    Methods:
        next(): Returns the next batch of data from the data loader, cycling through the data if necessary.
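
    Example:
        A minimal sketch, assuming a small in-memory list as the dataset::

            loader = CycledDataLoader(DataLoader(list(range(8)), batch_size=4))
            loader.next()  # tensor([0, 1, 2, 3])
            loader.next()  # tensor([4, 5, 6, 7])
            loader.next()  # wraps around: tensor([0, 1, 2, 3]) again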
    """

    def __init__(
        self,
        dataloader: DataLoader,
    ) -> None:
        self.dataloader = dataloader

        self.count = 0
        self.dataloader_iter = None

    def next(self) -> Any:
        """
        Returns the next batch of data from the data loader, cycling through the data if necessary.

        Returns:
            Any: The next batch of data from the data loader.
        """
        # Lazily create the iterator on the first fetch.
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # Exhausted the current pass: reset the counter and restart
            # from the beginning of the data.
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)


def is_rank_0() -> bool:
    """
    Check if the current process is the rank 0 process in a distributed training setup.

    Returns:
        bool: True if the current process is the rank 0 process, False otherwise.
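
    Example:
        Typical guard for once-per-job side effects such as logging (a sketch)::

            if is_rank_0():
                print("runs once per job, not once per process")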
    """
    return not dist.is_initialized() or dist.get_rank() == 0


def to_device(x: Any, device: torch.device) -> Any:
    """
    Move the input tensor or nested structure of tensors to the specified device.

    Args:
        x (Any): The input tensor or nested structure of tensors.
        device (torch.device): The target device to move the tensors to.

    Returns:
        Any: The tensor or nested structure of tensors moved to the target device.
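
    Example:
        Handles nested containers as well as bare tensors; the dict layout
        here is purely illustrative::

            batch = {"input_ids": torch.zeros(2, 8), "labels": [torch.ones(2)]}
            batch = to_device(batch, torch.device("cpu"))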
    """

    def _to(t: Any):
        if isinstance(t, torch.Tensor):
            return t.to(device)
        return t

    return tree_map(_to, x)


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    Perform all-reduce operation on the given tensor and compute the mean across all processes.

    Args:
        tensor (torch.Tensor): The input tensor to be reduced.

    Returns:
        torch.Tensor: The reduced tensor with mean computed across all processes.
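
    Example:
        A sketch; assumes the default process group is initialized. The
        reduction happens in place, so clone first if the input must survive::

            mean_loss = all_reduce_mean(loss.detach().clone())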
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    Performs an all-reduce operation to sum the values of the given tensor across all processes.

    Args:
        tensor (torch.Tensor): The input tensor to be reduced.

    Returns:
        torch.Tensor: The reduced tensor with the sum of values across all processes.
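
    Example:
        A sketch; assumes the default process group is initialized. As with
        all_reduce_mean, the input tensor is modified in place::

            total_tokens = all_reduce_sum(num_tokens.clone())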
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    return tensor