named_optimizer.py 3.61 KB
Newer Older
chenzk's avatar
v1.0.8  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

from nanotron import logging
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer

logger = logging.get_logger(__name__)


class NamedOptimizer(InheritFromOtherOptimizer):
    """Mimics somewhat the torch optimizer API, keyed by parameter *names*.

    Parameters are given either as an iterable of ``(name, tensor)`` pairs, or as an
    iterable of param-group dicts whose ``"named_params"`` entry holds such pairs. Every
    other key of a group dict is copied unchanged into the corresponding torch-style
    param group handed to ``optimizer_builder``.

    ``state_dict`` adds a ``"names"`` entry mapping each optimizer-state index to a
    parameter name. ``load_state_dict`` uses that entry to validate a checkpoint before
    loading it.
    """

    def __init__(
        self,
        named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]],
        optimizer_builder: Callable[[Iterable[Dict[str, Any]]], torch.optim.Optimizer],
    ):
        """Build the wrapped optimizer from named parameters.

        Args:
            named_params_or_groups: ``(name, param)`` pairs, or param-group dicts that each
                contain a ``"named_params"`` key holding ``(name, param)`` pairs.
            optimizer_builder: called once with the list of torch-style param groups
                (``"params"`` in place of ``"named_params"``); must return the optimizer to wrap.

        Raises:
            AssertionError: if a group lacks ``"named_params"``, two distinct params share
                a name, or a param has ``numel() == 0``.
        """
        named_param_groups = list(named_params_or_groups)
        # A flat sequence of (name, param) pairs (or no params at all) becomes a single group.
        if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict):
            named_param_groups = [{"named_params": named_param_groups}]

        id_to_name = {}
        params = []
        for named_param_group in named_param_groups:
            assert "named_params" in named_param_group
            # Materialize once: "named_params" may be a one-shot iterator (e.g. a generator),
            # and it is consumed twice below. Without this, the second pass sees nothing.
            named_params = list(named_param_group["named_params"])
            # Don't need to check that param_groups are overlapping since the optimizer will do it for me.
            #  https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473
            # A tensor already named by an earlier group keeps its earlier name.
            id_to_name.update({id(param): name for name, param in named_params if id(param) not in id_to_name})
            params.append(
                {
                    **{k: v for k, v in named_param_group.items() if k != "named_params"},
                    "params": [param for _, param in named_params],
                }
            )

        # Names must be unique across distinct params, otherwise they cannot key the state dict.
        name_to_id = {v: k for k, v in id_to_name.items()}
        assert len(id_to_name) == len(name_to_id), "Different parameters share the same name"

        # Sanity check
        for param_group in params:
            for param in param_group["params"]:
                # https://github.com/pytorch/pytorch/issues/100701
                assert param.numel() > 0
        super().__init__(optimizer=optimizer_builder(params), id_to_name=id_to_name)

    def state_dict(self) -> dict:
        """Return the parent's optimizer state dict plus a ``"names"`` entry.

        ``"names"`` maps every key of ``optim_state_dict["state"]`` to the name of the
        parameter owning that state.
        """
        optim_state_dict = super().state_dict()

        assert "names" not in optim_state_dict

        # Per-param state objects are matched by identity. This assumes the packed
        # ``"state"`` values are the same objects as in ``self.optimizer.state`` (true for
        # torch's default ``Optimizer.state_dict``). NOTE(review): breaks if an override deep-copies state.
        state_id_to_name = {id(state): self.id_to_name[id(param)] for param, state in self.optimizer.state.items()}
        optim_state_dict["names"] = {
            index: state_id_to_name[id(state)] for index, state in optim_state_dict["state"].items()
        }
        return optim_state_dict

    def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
        """Validate a name-annotated optimizer state dict, then delegate loading to the parent.

        Checks that:
          - ``state_dict["names"]`` covers exactly this optimizer's parameter names;
          - there is one ``"state"`` entry per name, and at least one entry;
          - every ``"state"`` entry contains each non-``"step"`` key found in a reference
            entry (entry ``0`` when present, otherwise the first entry).

        Raises:
            AssertionError: on any failed check.
        """
        expected_names = set(self.id_to_name.values())
        loaded_names = set(state_dict["names"].values())
        assert (
            expected_names == loaded_names
        ), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {expected_names - loaded_names}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {loaded_names - expected_names}"
        assert len(state_dict["state"]) == len(
            state_dict["names"]
        ), f"Number of params in loaded state dict ({len(state_dict['state'])}) doesn't match number of names ({len(state_dict['names'])})"
        assert len(state_dict["state"]) > 0, "Loading empty state dict"

        # Prefer entry 0 (the original reference). Fall back to the first entry so that dicts
        # whose keys are not 0-based ints do not crash here with a bare KeyError.
        all_states = state_dict["state"]
        reference_state = all_states[0] if 0 in all_states else next(iter(all_states.values()))
        optimizer_state_keys = sorted(reference_state.keys() - {"step"})
        for key in optimizer_state_keys:
            for k, state in all_states.items():
                assert (
                    key in state
                ), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"

        return super().load_state_dict(state_dict, map_location=map_location)