# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import fields

import pytest
import torch
import transformer_engine as te

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TENorm,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec, build_module, import_module
from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import is_te_min_version
from tests.unit_tests.test_utilities import Utils


class TestSpecCustomization:
    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        self.config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )

        # specify Transformer Layer spec with all identity ops
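        # (with no arguments, `TransformerLayerSubmodules` is expected to default
        #  every submodule field to `IdentityOp` / `IdentityFuncOp`, which is what
        #  `test_build_module` below relies on to build a no-op layer)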
        self.transformer_layer_spec = TransformerLayerSubmodules()

        # specify attention spec using already imported class
        self.attention_spec = ModuleSpec(
            module=SelfAttention,
            params={"attn_mask_type": AttnMaskType.causal},
            submodules=SelfAttentionSubmodules(
                linear_qkv=TELayerNormColumnParallelLinear,
                core_attention=TEDotProductAttention,
                linear_proj=TERowParallelLinear,
                q_layernorm=IdentityOp,
                k_layernorm=IdentityOp,
            ),
        )

        # specify layernorm spec with module path to test dynamic importing
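        # (the `(module_path, name)` tuple form is expected to be resolved by
        #  `import_module` only when the spec is built, rather than at
        #  spec-definition time)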
        self.layernorm_spec = ModuleSpec(
            module=("megatron.core.extensions.transformer_engine", "TENorm")
        )

        # specify bias dropout add with module path
        self.bda_spec = ModuleSpec(
            module=("megatron.core.fusions.fused_bias_dropout", "get_bias_dropout_add")
        )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_import_module(self):
        self_attention_cls = import_module(
            module_path=('megatron.core.transformer.attention', 'SelfAttention')
        )
        assert id(self_attention_cls) == id(SelfAttention)

        layernorm_cls = import_module(module_path=self.layernorm_spec.module)
        assert id(layernorm_cls) == id(TENorm)

    def test_build_module(self):
        # Check NoOp TransformerLayer
        random_input = 12
        noop_transformer_layer = [
            build_module(getattr(self.transformer_layer_spec, field.name))
            for field in fields(self.transformer_layer_spec)
            if field.name != 'sharded_state_dict_keys_map'
        ]

        x = random_input
        for mod in noop_transformer_layer:
            # check for `IdentityFuncOp` before `IdentityOp`: the former is
            # derived from the latter, so an `IdentityFuncOp` instance would
            # also match the `IdentityOp` branch.
            if isinstance(mod, IdentityFuncOp):
                x = mod()(x)
            elif isinstance(mod, IdentityOp):
                x = mod(x)

        assert x == random_input

        # Check SelfAttention
        self_attention = build_module(self.attention_spec, config=self.config, layer_number=1)
        assert isinstance(self_attention, SelfAttention)
        assert self_attention.layer_number == 1
        assert self_attention.attn_mask_type == self.attention_spec.params['attn_mask_type']

        num_weights = sum([p.numel() for p in self_attention.parameters()])
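        # Rough breakdown for hidden_size=12, assuming bias terms are enabled by
        # default: layernorm 12 + 12, QKV weight 3 * 12 * 12 = 432, QKV bias 36,
        # proj weight 12 * 12 = 144, proj bias 12 => 648 parameters in total.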
        assert num_weights == 648

        # Check SelfAttention again, but passing the already initialized module
        # `self_attention`. In that case, `build_module` acts as a no-op and
        # simply returns the initialized module.
        # NOTE: (sudhakars) Uncomment this test once this feature gets added
        # back.
        # self_attention2 = build_module(
        #     self_attention, config=self.config, spec=self.attention_spec,
        # )
        # assert isinstance(self_attention2, SelfAttention)
        # assert self_attention2.layer_number == 1
        # assert self_attention2.attn_mask_type == self.attention_spec.params['attn_mask_type']

        # num_weights = sum([p.numel() for p in self_attention2.parameters()])
        # assert num_weights == 648

        # Check LayerNorm
        layernorm = build_module(
            self.layernorm_spec,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        assert isinstance(layernorm, te.pytorch.LayerNorm)

        # Check BiasDropoutAdd
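        # (for a spec wrapping a plain function rather than an nn.Module class,
        #  `build_module` is expected to return the function object unchanged,
        #  which is what the identity check below verifies)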
        bda_op = build_module(self.bda_spec)
        assert id(bda_op) == id(get_bias_dropout_add)

    def test_sliding_window_attention(self):
        if not is_te_min_version("1.2.0"):
            pytest.skip("SWA not tested because TE version is not >= 1.2.0")

        config = TransformerConfig(
            num_layers=2,
            hidden_size=12,
            num_attention_heads=4,
            use_cpu_initialization=True,
            window_size=[10, 0],
        )
        # DotProductAttention does not support sliding-window attention (SWA),
        # so constructing it with a `window_size` in the config should raise.
        with pytest.raises(Exception):
            DotProductAttention(
                config, layer_number=1, attn_mask_type=AttnMaskType.causal, attention_type='self'
            )

        # Test TEDotProductAttention
        attn = TEDotProductAttention(
            config, layer_number=1, attn_mask_type=AttnMaskType.causal, attention_type='self'
        )
        # Make sure window-size is what we expect.
        assert attn.window_size == config.window_size

        # A single integer window size is unsupported; make sure it raises.
        config.window_size = 11
        with pytest.raises(Exception):
            TEDotProductAttention(
                config, layer_number=1, attn_mask_type=AttnMaskType.causal, attention_type='self'
            )

        # `None` makes this causal.
        config.window_size = None
        attn = TEDotProductAttention(
            config, layer_number=1, attn_mask_type=AttnMaskType.causal, attention_type='self'
        )
        # Make sure it's causal.
        assert attn.window_size == (-1, 0)
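        # (in TE's convention, (-1, 0) means an unbounded left context and no
        #  right context, i.e. standard causal attention)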

    def test_transformer_block_custom(self):
        """
        This test checks that the two ways of passing `layer_spec` to a
        `TransformerBlock` result in an identical model:
        1. ModuleSpec(module=..., submodules=...)
        2. TransformerBlockSubmodules(layer_specs=...)
        """

        transformer_config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )
        layer_local_spec = get_gpt_layer_local_spec()

        # A custom `TransformerLayer` spec can be passed as a single `ModuleSpec`;
        # internally, the `TransformerBlock` fans that one layer spec out to all
        # of the layers in the block.
        layer_spec1 = ModuleSpec(module=TransformerLayer, submodules=layer_local_spec.submodules)
        model_parallel_cuda_manual_seed(123)
        torch.manual_seed(0)
        parallel_transformer_block1 = TransformerBlock(transformer_config, layer_spec1)

        layer_spec2 = TransformerBlockSubmodules(
            layer_specs=[
                ModuleSpec(module=TransformerLayer, submodules=layer_local_spec.submodules)
            ]
            * transformer_config.num_layers,
            layer_norm=TENorm,
        )
        # make sure the model init conditions are identical
        model_parallel_cuda_manual_seed(123)
        torch.manual_seed(0)
        parallel_transformer_block2 = TransformerBlock(transformer_config, layer_spec2)

        sequence_length = 32
        micro_batch_size = 2
        parallel_transformer_block1.cuda()
        parallel_transformer_block2.cuda()

        # [sequence length, batch size, hidden size]
        hidden_states = torch.ones(
            (sequence_length, micro_batch_size, transformer_config.hidden_size)
        )
        hidden_states = hidden_states.cuda()

        attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()

        out1 = parallel_transformer_block1(
            hidden_states=hidden_states, attention_mask=attention_mask
        )
        out2 = parallel_transformer_block2(
            hidden_states=hidden_states, attention_mask=attention_mask
        )

        assert torch.all(torch.eq(out1, out2))
        assert out1.shape[0] == sequence_length == out2.shape[0]
        assert out1.shape[1] == micro_batch_size == out2.shape[1]
        assert out1.shape[2] == transformer_config.hidden_size == out2.shape[2]