from functools import cache
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union

import torch

from nanotron.optim.base import BaseOptimizer
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.parallel.parameters import NanotronParameter


class OptimizerFromGradientAccumulator(InheritFromOtherOptimizer):
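    """Wraps an optimizer so that it runs on the parameters exposed by a
    ``GradientAccumulator`` (typically full-precision copies of the model
    parameters) rather than on the model parameters themselves.

    The accumulator is built first from the flattened ``(name, param)`` pairs;
    each parameter group is then rewritten to point at
    ``gradient_accumulator.get_parameter_for_optimizer(name)`` before the inner
    optimizer is built. ``step``, ``zero_grad``, ``state_dict`` and
    ``load_state_dict`` are forwarded to both the inner optimizer and the
    accumulator.
    """
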
    def __init__(
        self,
        gradient_accumulator_builder: Callable[[Iterable[Tuple[str, NanotronParameter]]], GradientAccumulator],
        named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]],
        optimizer_builder: Callable[[Iterable[Dict[str, Any]]], BaseOptimizer],
    ):
        named_param_groups = list(named_params_or_groups)
        if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict):
            named_param_groups = [{"named_params": named_param_groups}]

        name_to_param = {}
        for named_param_group in named_param_groups:
            for name, param in named_param_group["named_params"]:
                if name in name_to_param:
                    raise ValueError(f"Duplicate key. {name} is already in `name_to_param`")
                else:
                    name_to_param[name] = param

        # Build gradient accumulator
        gradient_accumulator = gradient_accumulator_builder(name_to_param.items())
        self.gradient_accumulator = gradient_accumulator

        # Swap in the parameters that the gradient accumulator exposes to the optimizer
        converted_named_param_group = [
            {
                **{k: v for k, v in named_param_group.items() if k != "named_params"},
                "named_params": [
                    (name, gradient_accumulator.get_parameter_for_optimizer(name))
                    for name, _ in named_param_group["named_params"]
                ],
            }
            for named_param_group in named_param_groups
        ]
        optimizer = optimizer_builder(converted_named_param_group)

        super().__init__(optimizer=optimizer, id_to_name=optimizer.id_to_name)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        # Run the inner optimizer first, then let the accumulator do its post-step
        # work (e.g. syncing updated master weights back to the model parameters,
        # depending on the accumulator implementation).
        loss = super().step(closure)
        self.gradient_accumulator.step()
        return loss

    def zero_grad(self):
        # Clear gradients on both the inner optimizer and the gradient accumulator.
        super().zero_grad()
        return self.gradient_accumulator.zero_grad()

    @cache
    def state_dict_additional_keys(self) -> Set[str]:
        # "gradient_accumulator" is the extra top-level key this wrapper adds in `state_dict`.
        return super().state_dict_additional_keys() | {"gradient_accumulator"}

    def state_dict(self) -> dict:
        state_dict = super().state_dict()

        assert "gradient_accumulator" not in state_dict
        state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
        # Pop the accumulator state first so the inner optimizer only sees the keys it expects.
        gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator")
        super().load_state_dict(state_dict, map_location=map_location)
        self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict)
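

# A minimal usage sketch (illustrative assumptions, not part of this module):
# ``FP32GradientAccumulator`` and ``NamedOptimizer`` are assumed to live in the
# same ``nanotron.optim`` package, and the model / AdamW hyper-parameters are
# placeholders. Note that ``optimizer_builder`` must return an optimizer exposing
# ``id_to_name`` (see ``__init__`` above), which is why a named optimizer is used.
#
#   from nanotron.optim.gradient_accumulator import FP32GradientAccumulator
#   from nanotron.optim.named_optimizer import NamedOptimizer
#
#   optimizer = OptimizerFromGradientAccumulator(
#       gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(
#           named_parameters=named_params,
#       ),
#       named_params_or_groups=model.named_parameters(),
#       optimizer_builder=lambda named_param_groups: NamedOptimizer(
#           named_params_or_groups=named_param_groups,
#           optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups, lr=1e-4),
#       ),
#   )
#   grad_accumulator = optimizer.gradient_accumulator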