"INSTALL/vscode:/vscode.git/clone" did not exist on "d71514f23ee85c40e76b4b7391eff6fda2eca36f"
base.py 5.79 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import torch
from typing_extensions import TypeAlias

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]


class BaseOptimizer(ABC):
    id_to_name: Dict[int, str]
    param_groups: List[Dict[str, Any]]

    @abstractmethod
    def __getstate__(self):
        ...

    @abstractmethod
    def __setstate__(self, state):
        ...

    @abstractmethod
    def __repr__(self):
        ...

    @abstractmethod
    def zero_grad(self):
        ...

    @abstractmethod
    def state_dict_additional_keys(self) -> Set[str]:
        """Additional states we store in `state_dict`. It has to be a dictionary using parameter name as key, and a tensor as value."""
        ...

    @abstractmethod
    def state_dict(self) -> dict:
        ...

    @abstractmethod
    def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
        ...

    @abstractmethod
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        ...

    def inherit_from(self, cls) -> bool:
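        # Not abstract: presumably reports whether this optimizer (or the class it wraps)
        # derives from cls; the stub returns None, so concrete subclasses should override it.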
        ...


Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer)
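

# A minimal sketch, NOT part of the original module: one way a concrete subclass could satisfy
# the BaseOptimizer interface by delegating to a wrapped torch.optim.SGD. The class name and the
# named_params constructor argument are illustrative assumptions, not an established API.
class _ExampleSGDOptimizer(BaseOptimizer):
    def __init__(self, named_params: Dict[str, torch.Tensor], lr: float = 0.01):
        self._inner = torch.optim.SGD(list(named_params.values()), lr=lr)
        self.id_to_name = {id(p): name for name, p in named_params.items()}
        self.param_groups = self._inner.param_groups

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __repr__(self):
        return f"_ExampleSGDOptimizer({self._inner!r})"

    def zero_grad(self):
        self._inner.zero_grad()

    def state_dict_additional_keys(self) -> Set[str]:
        # No extra per-parameter state beyond what SGD already tracks.
        return set()

    def state_dict(self) -> dict:
        return self._inner.state_dict()

    def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
        # map_location is ignored in this sketch; see custom_load_state_dict below for a device-aware variant.
        self._inner.load_state_dict(state_dict)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        return self._inner.step(closure)

    def inherit_from(self, cls) -> bool:
        return isinstance(self._inner, cls)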


# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
def _process_value_according_to_param_policy(
    param: torch.Tensor,
    value: torch.Tensor,
    param_id: int,
    param_groups: List[Dict[Any, Any]],
    map_location: Optional[Union[str, torch.device]],
    key: Hashable = None,
) -> torch.Tensor:
    # If map_location is specified, use it instead of param.device
    target_device = map_location if map_location is not None else param.device

    fused = False
    capturable = False
    assert param_groups is not None
    for pg in param_groups:
        if param_id in pg["params"]:
            fused = pg["fused"] if "fused" in pg else False
            capturable = pg["capturable"] if "capturable" in pg else False
            break

    if key == "step":
        if capturable or fused:
            return value.to(dtype=torch.float32, device=target_device)
        else:
            return value
    else:
        if param.is_floating_point():
            return value.to(dtype=param.dtype, device=target_device)
        else:
            return value.to(device=target_device)


# Modified from torch.optim.Optimizer.load_state_dict
@torch._disable_dynamo
def custom_load_state_dict(self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = None) -> None:
    r"""Loads the optimizer state.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
        map_location (str or torch.device, optional): Device where to load the optimizer states.
            If None, states will be loaded to the same device as their corresponding parameters.
            Default: None
    """

    # shallow copy, to be consistent with module API
    state_dict = state_dict.copy()

    for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
        hook_result = pre_hook(self, state_dict)
        if hook_result is not None:
            state_dict = hook_result

    # Validate the state_dict
    groups = self.param_groups

    # Deepcopy as we write into saved_groups later to update state
    saved_groups = deepcopy(state_dict["param_groups"])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of parameter groups")
    param_lens = (len(g["params"]) for g in groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group that doesn't match the size of optimizer's group"
        )

    # Update the state
    id_map = dict(
        zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
    )
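    # id_map pairs each saved parameter id/index with the corresponding live parameter
    # object in this optimizer, matching saved groups to current groups positionally.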

    def _cast(param, value, param_id=None, param_groups=None, key=None):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
        elif isinstance(value, dict):
            return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
        elif isinstance(value, Iterable):
            return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
    for k, v in state_dict["state"].items():
        if k in id_map:
            param = id_map[k]
            state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
        else:
            state[k] = v

    # Update parameter groups, setting their 'params' value
    def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
        new_group["params"] = group["params"]
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({"state": state, "param_groups": param_groups})

    for post_hook in self._optimizer_load_state_dict_post_hooks.values():
        post_hook(self)
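

# A usage sketch, NOT part of the original module: custom_load_state_dict can be bound onto an
# existing torch optimizer instance so that its load_state_dict accepts map_location.
# The helper name and the checkpoint path below are illustrative assumptions.
def _attach_custom_load_state_dict(optimizer: torch.optim.Optimizer) -> None:
    import types

    # Override this instance's load_state_dict with the map_location-aware version above.
    optimizer.load_state_dict = types.MethodType(custom_load_state_dict, optimizer)


# Example (illustrative):
#     model = torch.nn.Linear(4, 4)
#     opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     _attach_custom_load_state_dict(opt)
#     opt.load_state_dict(torch.load("opt_state.pt"), map_location="cpu")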