# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Megatron grad scaler."""

from abc import ABC, abstractmethod
from typing import Dict

import torch


class MegatronGradScaler(ABC):
    def __init__(self, initial_scale: float):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
        self._scale = torch.tensor([initial_scale], dtype=torch.float, device='cuda')

    @property
    def scale(self):
        return self._scale

    @property
    def inv_scale(self):
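        # Take the reciprocal in double precision (presumably to avoid fp32
        # rounding error in the division) and cast the result back to fp32.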
        return self._scale.double().reciprocal().float()

    @abstractmethod
    def update(self, found_inf: bool):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict: Dict):
        pass


class ConstantGradScaler(MegatronGradScaler):
    """
    Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
    """

    def update(self, found_inf: bool):
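        """No-op: the loss scale stays constant regardless of inf/NaN gradients."""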
        pass

    def state_dict(self):
        return dict()

    def load_state_dict(self, state_dict):
        pass


class DynamicGradScaler(MegatronGradScaler):
    """
    Grad scaler with dynamic scale that gets adjusted during training.

    Reduces the loss scale by `backoff_factor` if NaNs are seen in `hysteresis` consecutive
    iterations. Increases the loss scale by `growth_factor` if no NaNs are seen for
    `growth_interval` iterations.
    """

    def __init__(
        self,
        initial_scale: float,
        min_scale: float,
        growth_factor: float,
        backoff_factor: float,
        growth_interval: int,
        hysteresis: int,
    ):
        """
        Grad scaler with dynamic scale that gets adjusted during training.

        Args:
            initial_scale (float): Initial loss scale value.
            min_scale (float): Minimum loss scale value.
            growth_factor (float): Factor to grow loss scale by if NaNs are not seen in `growth_interval`
                training iterations. Must be greater than 1.
            backoff_factor (float): Factor to decrease loss scale by if NaNs are seen in `hysteresis`
                consecutive training iterations. Must be between 0 and 1.
            growth_interval (int): Number of training iterations of no NaNs before loss scale is increased.
            hysteresis (int): Number of training iterations of consecutive NaNs before loss scale is decreased.
        """
        super().__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        self.min_scale = torch.tensor([min_scale], dtype=torch.float, device='cuda')
        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        self.growth_factor = torch.tensor([growth_factor], dtype=torch.float, device='cuda')
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        self.backoff_factor = torch.tensor([backoff_factor], dtype=torch.float, device='cuda')
        # Interval over which if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval
        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

    def update(self, found_inf: bool):
        """
        Updates internal state in grad scaler based on whether NaNs are seen in grads or not.
        """

        # If we have an inf/nan, the growth tracker is reset to 0
        # and the hysteresis tracker is reduced by 1.
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            # Now if we are out of hysteresis count, scale down the loss.
            if self._hysteresis_tracker <= 0:
                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
            # If we have had enough consecutive intervals with no nan/inf:
            if self._growth_tracker == self.growth_interval:
                # Reset the tracker and hysteresis trackers,
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                self._scale = self._scale * self.growth_factor

    def state_dict(self):
        state_dict = {}
        state_dict['scale'] = self._scale
        state_dict['growth_tracker'] = self._growth_tracker
        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
        self._growth_tracker = state_dict['growth_tracker']
        self._hysteresis_tracker = state_dict['hysteresis_tracker']
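

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# one plausible way to drive DynamicGradScaler from a toy fp16 training step:
# scale the loss before backward(), check gradients for inf/NaN, unscale and
# step only when the gradients are finite, and feed the result to update().
# The toy model, optimizer, loop, and inline inf/NaN check are assumptions
# made for this sketch; they are not Megatron APIs.
# ---------------------------------------------------------------------------
def _dynamic_grad_scaler_demo():
    if not torch.cuda.is_available():
        return  # Both scalers allocate CUDA tensors, so a GPU is required.

    scaler = DynamicGradScaler(
        initial_scale=2.0**16,
        min_scale=1.0,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=1000,
        hysteresis=2,
    )
    model = torch.nn.Linear(8, 8).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(10):
        optimizer.zero_grad()
        inputs = torch.randn(4, 8, device='cuda', dtype=torch.half)
        loss = model(inputs).float().sum()

        # Scale the loss before backward so small fp16 gradients do not underflow.
        (loss * scaler.scale).backward()

        # Check gradients for inf/NaN before unscaling and stepping.
        found_inf = any(
            not torch.isfinite(p.grad).all() for p in model.parameters() if p.grad is not None
        )
        if not found_inf:
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.mul_(scaler.inv_scale.half())  # unscale the gradients
            optimizer.step()

        # Let the scaler adjust the loss scale based on whether inf/NaN was seen.
        scaler.update(found_inf)


if __name__ == "__main__":
    _dynamic_grad_scaler_demo()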