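# Tests for megatron.core.parallel_state. These assume a multi-process GPU
# launch that sets LOCAL_RANK for each process, e.g. (hypothetical invocation):
#   torchrun --nproc_per_node=<num_gpus> -m pytest test_parallel_state.py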
import os
import torch
import megatron.core.parallel_state as ps
from datetime import timedelta
import pytest


world_size = torch.cuda.device_count()
rank = int(os.environ['LOCAL_RANK'])
print(f'Rank is: {rank}')

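# Bootstraps a NCCL process group over TCP; MASTER_ADDR and MASTER_PORT fall
# back to localhost:6000 when the launcher does not export them.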
def initialize_distributed():
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    torch.cuda.set_device(rank % torch.cuda.device_count())
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = f'tcp://{master_ip}:{master_port}'
    torch.distributed.init_process_group(
        backend='nccl', world_size=world_size, rank=rank,
        init_method=init_method, timeout=timedelta(seconds=10),
    )

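# Helper that (re)initializes megatron's model parallel state. Earlier tests
# may leave global state behind, so any existing state is torn down first.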
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_split_rank=None,
):
    ps.destroy_model_parallel()
    ps.initialize_model_parallel(
        tensor_model_parallel_size,
        pipeline_model_parallel_size,
        virtual_pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank,
    )

def test_initialize_model_parallel():
    # Model parallel setup requires torch.distributed to be initialized
    # first, so this call should trip an assertion.
    with pytest.raises(AssertionError):
        ps.initialize_model_parallel()
    initialize_distributed()
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size)
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(pipeline_model_parallel_size=2 * world_size)
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(
            pipeline_model_parallel_size=world_size,
            tensor_model_parallel_size=world_size,
        )
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2)
    initialize_model_parallel()

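    # After a successful init, every process group megatron exposes should
    # exist on this rank.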
    assert(ps.model_parallel_is_initialized())
    assert(ps.get_model_parallel_group() is not None)
    assert(ps.get_tensor_model_parallel_group() is not None)
    assert(ps.get_pipeline_model_parallel_group() is not None)
    assert(ps.get_data_parallel_group() is not None)  
    assert(ps.get_embedding_group() is not None)  
    assert(ps.get_position_embedding_group() is not None)
    ps.destroy_model_parallel()

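# With two pipeline stages the next/prev pipeline ranks wrap around in a ring.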
def test_pipeline_parallel_initializations():
    initialize_model_parallel(pipeline_model_parallel_size=2)
    assert(ps.get_pipeline_model_parallel_first_rank() == 0)
    assert(ps.get_data_parallel_src_rank() == rank)
    assert(ps.get_pipeline_model_parallel_next_rank() == (0 if rank == world_size - 1 else rank + 1))
    assert(ps.get_pipeline_model_parallel_prev_rank() == (rank - 1 if rank > 0 else world_size - 1))
    ps.destroy_model_parallel()

def test_data_parallel_initializations():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_data_parallel_src_rank() == rank)
    # With pipeline parallelism spanning every rank, each data parallel group
    # holds exactly one rank.
    assert(ps.get_data_parallel_world_size() == 1)
    assert(ps.get_data_parallel_rank() == 0)
    ps.destroy_model_parallel()

def test_tensor_model_parallel_world_size():
    initialize_model_parallel(tensor_model_parallel_size=world_size)
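    # Setting the cached world size to None makes the getter fall back to the
    # size of the tensor model parallel process group.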
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    ps.set_tensor_model_parallel_world_size(None)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    ps.destroy_model_parallel()


def test_pipeline_model_parallel_world_size():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
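    # Same override/fallback behaviour as the tensor parallel world size test.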
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    ps.set_pipeline_model_parallel_world_size(None)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    ps.destroy_model_parallel()


def test_tensor_model_parallel_rank():
    initialize_model_parallel(tensor_model_parallel_size=world_size)
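    # Clearing the cached rank falls back to this process's rank within the
    # tensor model parallel group.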
    assert(ps.get_tensor_model_parallel_rank() == rank)
    ps.set_tensor_model_parallel_rank(None)
    assert(ps.get_tensor_model_parallel_rank() == rank)    
    ps.destroy_model_parallel()

def test_pipeline_model_parallel_rank():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
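    # Same override/fallback behaviour as the tensor parallel rank test.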
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    ps.set_pipeline_model_parallel_rank(None)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    ps.destroy_model_parallel()

def test_is_pipeline_first_stage():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
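    # One rank per pipeline stage here, so only global rank 0 is first stage.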
    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
    assert(ps.is_pipeline_first_stage() == (rank == 0))
    ps.destroy_model_parallel()

def test_is_pipeline_last_stage():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
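    # One rank per pipeline stage here, so only the highest global rank is
    # the last stage.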
    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
    assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
    ps.destroy_model_parallel()


def test_virtual_pipeline_model_parallel_rank():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
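    # The virtual pipeline rank is a plain module-level value, so it should
    # round-trip through the setter/getter unchanged.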
    ps.set_virtual_pipeline_model_parallel_rank(rank)
    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
    ps.destroy_model_parallel()

def test_get_tensor_model_parallel_src_rank():
    initialize_model_parallel(tensor_model_parallel_size=world_size)
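    # A single tensor parallel group spans all ranks, so every rank's source
    # rank floors to 0 (the first rank of its group).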
    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
    ps.destroy_model_parallel()

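# initialize_model_parallel would normally create the buffer; here the module
# global is reset and rebuilt directly to exercise the helper in isolation.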
def test_global_memory_buffer():
    ps._GLOBAL_MEMORY_BUFFER = None
    ps._set_global_memory_buffer()
    assert(ps.get_global_memory_buffer() is not None)


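# The tests below are deliberately disabled (kept inside a string literal);
# per the TODOs they still need work for world sizes greater than one.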
"""

def test_get_virtual_pipeline_model_parallel_world_size():
    initialize_model_parallel(pipeline_model_parallel_size=world_size)
    # Set the module global directly (as test_global_memory_buffer does);
    # setting the virtual rank does not affect the world size.
    ps._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
    assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
    ps.destroy_model_parallel()


def test_is_rank_in_embedding_group():
    assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
    # First and last members of the embedding group track the pipeline's
    # first/last stage; interior members always qualify.
    if rank == ps._EMBEDDING_GLOBAL_RANKS[0]:
        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
    elif rank == ps._EMBEDDING_GLOBAL_RANKS[-1]:
        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
    elif rank in ps._EMBEDDING_GLOBAL_RANKS:
        assert(ps.is_rank_in_embedding_group())

def test_is_rank_in_position_embedding_group():
    assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))

def test_is_pipeline_stage_before_split():
    if world_size == 1:
        assert(ps.is_pipeline_stage_before_split())
    # TODO: Changes here for more than one world size
    assert(ps.is_pipeline_stage_before_split())

def test_is_pipeline_stage_after_split():
    if world_size == 1:
        assert(ps.is_pipeline_stage_after_split())
    # TODO: Changes here for more than one world size
    assert(ps.is_pipeline_stage_after_split())

def test_is_pipeline_stage_at_split():
    assert(
        ps.is_pipeline_stage_at_split() == 
        (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
        )

def test_destroy_model_parallel():
    ps.destroy_model_parallel()
    assert(ps._MODEL_PARALLEL_GROUP is None)
"""