fp16_layers.py
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter


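# Legacy tensor-type tuples, used only for isinstance() dtype checks below.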
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
    #is a nested tuple/list structure."""
    if isinstance(val, dict):
        res_dict = {}
        for k, v in val.items():
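            # "cos_cis_img"/"sin_cis_img" (presumably precomputed rotary
            # position-embedding caches) are passed through unchanged so they
            # keep their original dtype.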
            if k != "cos_cis_img" and k != "sin_cis_img":
                res_dict[k] = conversion_helper(v, conversion)
            else:
                res_dict[k] = v
        return res_dict
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16"""

    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val

    return conversion_helper(val, half_conversion)


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32"""

    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _HALF_TYPES):
            val = val.float()
        return val

    return conversion_helper(val, float_conversion)


class Float16Module(torch.nn.Module):
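    """Wrapper that runs `module` in fp16: parameters are converted via
    .half(), fp32 inputs are cast to fp16 on the way in, and fp16 outputs
    are cast back to fp32 on the way out."""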

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        self.add_module("module", module.half())

        def float16_convertor(val):
            return val.half()

        self.float16_convertor = float16_convertor

        self.config = self.module.config
        self.dtype = torch.float16

    def forward(self, *inputs, **kwargs):
        inputs = fp32_to_float16(inputs, self.float16_convertor)
        kwargs = fp32_to_float16(kwargs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(
        self, destination=None, prefix="", keep_vars=False
    ):
        return self.module.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
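

# Illustrative smoke test (an addition, not part of the original module): round-trips
# a nested structure through the converters above and checks that the skipped keys
# keep their original dtype. Runs on CPU and does not need CUDA.
if __name__ == "__main__":
    sample = {"x": torch.randn(2, 3), "cos_cis_img": torch.randn(4)}
    as_half = fp32_to_float16(sample, lambda t: t.half())
    assert as_half["x"].dtype == torch.float16            # converted to fp16
    assert as_half["cos_cis_img"].dtype == torch.float32  # skipped key passes through
    restored = float16_to_fp32(as_half)
    assert restored["x"].dtype == torch.float32           # cast back to fp32
    print("fp32 <-> fp16 conversion round-trip OK")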