# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.module import Float16Module, MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils

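# DEVICE_CAPABILITY is used below to skip the bfloat16 test on GPUs without
# bf16 support (compute capability < 8.0, i.e. pre-Ampere).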
DEVICE_CAPABILITY = None
if torch.cuda.is_available():
    DEVICE_CAPABILITY = torch.cuda.get_device_capability()


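# Minimal MegatronModule subclass wrapping a single Linear layer: just enough
# model to exercise config plumbing and Float16Module dtype conversion.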
class DummyModule(MegatronModule):
    # def __init__(self, config: TransformerConfig, share_embeddings_and_output_weights=True):
    def __init__(self, config: TransformerConfig):
        super().__init__(config)

        self.linear = torch.nn.modules.Linear(in_features=2, out_features=1)

    def forward(self, x):
        return self.linear(x)


class TestMegatronModule:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )
        self.megatron_module = DummyModule(config=transformer_config).cuda()

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_megatron_module(self):
        megatron_module = self.megatron_module
        assert megatron_module
        assert megatron_module.config.hidden_size == 12
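        # ffn_hidden_size defaults to 4 * hidden_size (4 * 12 = 48) when it is not set explicitly.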
        assert megatron_module.config.ffn_hidden_size == 48
        assert megatron_module.linear.weight.dtype == torch.float32

        x = torch.ones((2, 2)).cuda()
        assert megatron_module(x).dtype == torch.float32

        # TODO: test bad configs actually fail
        # failed_module = megatron_module
        # failed_module.fp16 = True
        # failed_module.bf16 = True
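        # Hedged sketch of one possible check (assumes, without verifying here, that
        # Float16Module raises when the config enables neither fp16 nor bf16):
        # with pytest.raises(Exception):
        #     Float16Module(config=megatron_module.config, module=megatron_module)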


class TestFloat16Module:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        self.transformer_config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )
        self.megatron_module = DummyModule(config=self.transformer_config).cuda()

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_fp16_module(self):
        transformer_config = self.transformer_config
        megatron_module = self.megatron_module
        transformer_config.fp16 = True
        fp16_module = Float16Module(config=transformer_config, module=megatron_module)

        assert fp16_module
        assert fp16_module.config.hidden_size == 12
        assert fp16_module.config.ffn_hidden_size == 48
        assert fp16_module.module.linear.weight.dtype == torch.float16

        x = torch.ones((2, 2)).cuda()
        # inputs are converted to fp16 then outputs are converted to fp32
        assert fp16_module(x).dtype == torch.float32

    @pytest.mark.skipif(
        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
        reason='bfloat16 is not supported on this device',
    )
    def test_bf16_module(self):
        transformer_config = self.transformer_config
        megatron_module = self.megatron_module
        transformer_config.bf16 = True
        bf16_module = Float16Module(config=transformer_config, module=megatron_module)

        assert bf16_module
        assert bf16_module.config.hidden_size == 12
        assert bf16_module.config.ffn_hidden_size == 48
        assert bf16_module.module.linear.weight.dtype == torch.bfloat16

        x = torch.ones((2, 2)).cuda()
        # inputs are converted to bf16 then outputs are converted to fp32
        assert bf16_module(x).dtype == torch.float32