#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn

# apex is an optional dependency; guard the import so that this module can be
# imported even when apex is not installed (the wrapper below will then raise
# a NameError if it is actually used)
try:
    import apex.amp as apex_amp
except ImportError:
    pass

from torch import Tensor

from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class ApexAMPOptimizer(ColossalaiOptimizer):
    """A wrapper class for an apex.amp optimizer. It implements the apex-specific
    backward and clip_grad_norm methods.
    """

    def backward(self, loss: Tensor):
        """Backward pass to compute all gradients.

        Args:
            loss (torch.Tensor): Loss computed by a loss function
        """
        # apex scales the loss to prevent fp16 gradient underflow; the gradients
        # are unscaled again by apex before the optimizer step
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip gradients by norm.

        Args:
            model (torch.nn.Module): Your model object
            max_norm (float): The max norm value for gradient clipping
        """
        if max_norm > 0:
            # clip the fp32 master copies of the parameters maintained by apex
            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
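

# A minimal usage sketch (illustrative, not part of the original file): it assumes
# apex is installed and a CUDA device is available, and the toy model/optimizer
# below are hypothetical stand-ins.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(16, 4).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=1e-2)
    # apex.amp.initialize patches the model/optimizer pair for mixed precision
    model, optim = apex_amp.initialize(model, optim, opt_level='O1')
    optimizer = ApexAMPOptimizer(optim)

    loss = model(torch.randn(8, 16, device='cuda')).sum()
    optimizer.backward(loss)                   # scales the loss, then backward
    optimizer.clip_grad_norm(model, max_norm=1.0)
    optimizer.step()                           # delegates to the wrapped optimizer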