# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import collections.abc as abc
from dataclasses import dataclass
from math import inf
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist


@dataclass
class Workhandle:
    handle: Any
    callback: Optional[Callable] = None
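
# Example (hypothetical usage sketch, not part of the original file):
# a Workhandle pairs the request object returned by an async collective
# with an optional callback to run once the communication completes.
#
#   tensor = torch.ones(1)
#   work = Workhandle(handle=dist.all_reduce(tensor, async_op=True),
#                     callback=lambda: print("all_reduce finished"))
#   work.handle.wait()
#   if work.callback is not None:
#       work.callback()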


def get_global_rank(group: Any, rank: int) -> int:
    """Translate a rank within ``group`` into the corresponding rank in the global (WORLD) group."""
    if group is dist.group.WORLD:
        return rank

    # For sub-groups, fall back on a private torch.distributed helper.
    return dist.distributed_c10d._get_global_rank(group, rank)
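
# Example (hypothetical sketch, not part of the original file): for a
# sub-group built from global ranks {2, 3}, local rank 0 maps back to
# global rank 2.
#
#   group = dist.new_group(ranks=[2, 3])
#   get_global_rank(group, 0)  # -> 2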


# Credits:  classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to device if
    possible. Non-tensor values are passed as-is in the result.

    NOTE: The result contains copies, not references: if two entries in the
    input point to the same tensor, the output will contain two distinct
    tensors on the device, so aliasing is not preserved.
    """

    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = []
        for val in value:
            values.append(recursive_copy_to_device(val, non_blocking=non_blocking, device=device))

        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, abc.Mapping):
        device_val: Dict[str, Any] = {}
        for key, val in value.items():
            device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)

        return device_val

    return value
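
# Example (hypothetical sketch, not part of the original file): moving a
# nested, optimizer-state-like structure to CPU; the non-tensor "step"
# entry passes through unchanged.
#
#   state = {"step": 3, "exp_avg": torch.ones(4), "buffers": [torch.zeros(2)]}
#   cpu_state = recursive_copy_to_device(
#       state, non_blocking=False, device=torch.device("cpu")
#   )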


def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
    r"""Calculate the p-norm of the gradients of an iterable of parameters.

    Parameters whose ``.grad`` is None are ignored.

    Returns:
        Total norm of the gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda par: par.grad is not None, parameters))

    if len(parameters) == 0:
        return torch.tensor(0.0)
    p = float(p)
    if p == inf:
        local_norm = max(par.grad.detach().abs().max() for par in parameters)  # type: ignore
    else:
        # Compute the norm in full precision no matter what, then cast the
        # result back to the parameters' dtype.
        local_norm = torch.norm(
            torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]),
            p,
        ).to(dtype=parameters[0].dtype)  # type: ignore
    return local_norm
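
# Example (hypothetical sketch, not part of the original file): combining
# per-rank results into a global 2-norm when each rank holds a disjoint
# shard of the parameters; squared local norms sum across ranks.
#
#   local_norm = calc_grad_norm(list(shard.parameters()), p=2.0)
#   squared = local_norm ** 2.0
#   dist.all_reduce(squared)  # defaults to SUM across ranks
#   global_norm = squared ** 0.5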