test_initialization.py 3.96 KB
Newer Older
liangjing's avatar
liangjing committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

import megatron.core.parallel_state as ps
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.tensor_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class Test:
    """Check that tensor-parallel layers initialize consistently across TP sizes.

    Each test builds a layer once with tensor-model-parallel size 1 and once
    with size ``_TP_SIZE``. ``torch.manual_seed`` is identical for both builds,
    while ``model_parallel_cuda_manual_seed`` is deliberately different for the
    sharded build. The test then asserts that this rank's shard equals the
    matching slice of the full TP=1 weight.
    """

    # Tensor-model-parallel size used for the sharded build.
    _TP_SIZE = 4

    transformer_config = TransformerConfig(
        num_layers=1, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
    )

    def _build_full_and_shard(self, layer_cls, **layer_kwargs):
        """Build ``layer_cls(**layer_kwargs)`` at TP=1 and TP=``_TP_SIZE``.

        The model-parallel state is destroyed after each build, in a
        ``finally`` so that it is released even if construction fails and
        does not leak into later tests.

        Args:
            layer_cls: Layer class to instantiate; its instances must expose a
                ``weight`` attribute.
            **layer_kwargs: Keyword arguments forwarded to ``layer_cls``.

        Returns:
            ``(full, shard, rank)``: the TP=1 weight, this rank's
            TP=``_TP_SIZE`` weight, and this rank's tensor-model-parallel rank
            in the sharded build.
        """
        Utils.initialize_model_parallel(1, 1)
        try:
            torch.manual_seed(42)
            model_parallel_cuda_manual_seed(42)
            full = layer_cls(**layer_kwargs).weight
        finally:
            Utils.destroy_model_parallel()

        Utils.initialize_model_parallel(self._TP_SIZE, 1)
        try:
            torch.manual_seed(42)
            # Intentionally different from the TP=1 build. The assertions expect
            # identical weights anyway, so initialization must not depend on
            # this seed (presumably because use_cpu_initialization=True draws
            # from the CPU RNG seeded above; confirm in tensor_parallel.layers).
            model_parallel_cuda_manual_seed(41)
            shard = layer_cls(**layer_kwargs).weight
            rank = ps.get_tensor_model_parallel_rank()
        finally:
            Utils.destroy_model_parallel()

        return full, shard, rank

    def _assert_is_shard(self, full, shard, rank, dim):
        """Assert ``shard`` is the ``rank``-th of ``_TP_SIZE`` equal slices of ``full`` along ``dim``."""
        # Derive the slice width from the shard itself rather than hard-coding
        # it, so changing the layer sizes keeps the comparison aligned.
        width = shard.shape[dim]
        assert width * self._TP_SIZE == full.shape[dim]
        assert torch.equal(full.narrow(dim, rank * width, width), shard)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_embedding_init(self):
        """VocabParallelEmbedding shards its weight along dim 0."""
        full, shard, rank = self._build_full_and_shard(
            VocabParallelEmbedding,
            num_embeddings=16,
            embedding_dim=4,
            init_method=self.transformer_config.init_method,
            config=self.transformer_config,
        )
        self._assert_is_shard(full, shard, rank, dim=0)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_row_init(self):
        """RowParallelLinear shards its weight along dim 1."""
        full, shard, rank = self._build_full_and_shard(
            RowParallelLinear,
            input_size=16,
            output_size=16,
            init_method=self.transformer_config.init_method,
            bias=True,
            input_is_parallel=False,
            config=self.transformer_config,
            skip_bias_add=False,
        )
        self._assert_is_shard(full, shard, rank, dim=1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_col_init(self):
        """ColumnParallelLinear shards its weight along dim 0."""
        full, shard, rank = self._build_full_and_shard(
            ColumnParallelLinear,
            input_size=16,
            output_size=16,
            init_method=self.transformer_config.init_method,
            bias=True,
            config=self.transformer_config,
            skip_bias_add=False,
        )
        self._assert_is_shard(full, shard, rank, dim=0)