utils.py 393 Bytes
Newer Older
Hongxin Liu's avatar
Hongxin Liu committed
1
2
from typing import Any

3
import torch
Hongxin Liu's avatar
Hongxin Liu committed
4
5
import torch.distributed as dist
from torch.utils._pytree import tree_map
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
6
7
8
9


def is_rank_0() -> bool:
    """Tell whether the current process should act as rank 0.

    Returns:
        ``True`` when ``torch.distributed`` has not been initialized, or
        when it has been initialized and this process's rank is 0;
        ``False`` otherwise.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    # Without an initialized process group there is no rank to query, so
    # the process is treated as rank 0.
    return True
10
11


Hongxin Liu's avatar
Hongxin Liu committed
12
13
14
15
16
17
18
19
def to_device(x: Any, device: torch.device) -> Any:
    """Move every tensor found in a nested container onto ``device``.

    ``x`` is traversed with ``tree_map``, so the result has the same
    container structure as the input.

    Args:
        x: A tensor, a nested container of values, or any other object.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The same structure as ``x`` where each ``torch.Tensor`` leaf has
        been replaced by ``leaf.to(device)``; all other leaves are passed
        through unchanged.
    """

    def _move_leaf(leaf: Any) -> Any:
        # Only tensors are moved; every other leaf type is returned as-is.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, x)