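"""Utilities for replacing BERT-style transformer layers in a model with
DeepSpeed's fused ``DeepSpeedTransformerLayer``, and for reverting such a
replacement, copying the relevant parameters in each direction."""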
import torch
import deepspeed


def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed=-1,
                              preln=True,
                              fp16=True,
                              training=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        micro_batch_size (int): micro batch size per gpu used during training/eval
        bert_config: model config object (e.g. a ``BertConfig``) exposing hidden_size,
            num_attention_heads, and the related attributes read below
        seed (int): random seed value
        preln (bool): whether the original layer implementation uses pre-layer-norm
            (True) or post-layer-norm (False)
        fp16 (bool): use fp16 (True) or fp32 (False) precision
        training (bool): select between training (True) and inference (False) mode
        huggingface (bool): whether the original implementation is HuggingFace's,
            which differs in that it supports both encoder and decoder modes
        local_rank (int): local rank of the current GPU device (-1 if not set)

    Returns:
        Updated nn.module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        # Fuse Q/K/V into single parameters of shape [3 * hidden_size, hidden_size]
        # (weights) and [3 * hidden_size] (biases); DeepSpeed's kernel computes the
        # three projections with one fused GEMM.
        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)

        #qw.data, kw.data, vw.data = torch.chunk(qkvw, 3, dim=0)
        #qb.data, kb.data, vb.data = torch.chunk(qkvb, 3, dim=0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)
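
# Example usage (a minimal sketch, not part of the original source; assumes a
# HuggingFace-style BERT whose encoder layers are ``BertLayer`` instances and
# whose config exposes the attributes read in ``replace_fn`` above):
#
#   from transformers import BertConfig, BertModel
#   from transformers.modeling_bert import BertLayer
#
#   config = BertConfig()
#   model = BertModel(config)
#   model = replace_transformer_layer(BertLayer,
#                                     model,
#                                     micro_batch_size=8,
#                                     bert_config=config,
#                                     preln=False,
#                                     huggingface=True)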


def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        bert_config: model config object (e.g. a ``BertConfig``) used to construct
            the original layer
        preln (bool): whether the original layer implementation uses pre-layer-norm
            (True) or post-layer-norm (False)

    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)
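
# Example (a sketch, assuming ``model`` was previously converted with
# ``replace_transformer_layer`` above and ``config`` is the matching config):
#
#   model = revert_transformer_layer(BertLayer, model, config, preln=False)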


def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)
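
# Example (illustrative only; not from the original source): ``replace_module``
# works for any module class, e.g. swapping every ``nn.ReLU`` for ``nn.GELU``:
#
#   import torch.nn as nn
#
#   model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#   model = replace_module(model, nn.ReLU, lambda child: nn.GELU())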


def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace_module(child, policies)

    return model
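
# Note: a child that matches a policy is swapped out and its own submodules are
# not traversed further, so nested occurrences inside a replaced module are left
# untouched. Minimal sketch of the traversal semantics (illustrative only):
#
#   import torch.nn as nn
#
#   policies = {nn.ReLU: lambda child: nn.GELU()}
#   model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#   _replace_module(model, policies)  # model[1] is now an nn.GELU instance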