# test_parallel_state.py
import os
import torch
import megatron.core.parallel_state as ps
from datetime import timedelta
import pytest
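# These tests exercise megatron.core.parallel_state on a single rank
# (world_size = 1, all parallel sizes 1), so every rank / world-size query
# below is expected to return 0 or 1. They rely on pytest collecting the
# tests in file order: test_initialize_model_parallel() sets up
# torch.distributed and the model-parallel groups that the later tests query.
# Example invocation (assumes a single visible GPU; MASTER_ADDR and
# MASTER_PORT default to localhost:6000 below):
#   pytest -s test_parallel_state.py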

# TODO: Maybe get these values from environment variables
rank = torch.cuda.current_device()
world_size = 1 #torch.cuda.device_count()
tensor_model_parallel_size = 1
pipeline_model_parallel_size = 1
virtual_pipeline_model_parallel_size = None
pipeline_model_parallel_split_rank = None

def initialize_distributed():
    rank = torch.cuda.current_device()   
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    torch.cuda.set_device(rank % torch.cuda.device_count())
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method=init_method, timeout=timedelta(seconds=10))

def test_initialize_model_parallel():
    # Calling initialize_model_parallel() before torch.distributed is
    # initialized should fail its internal assertion.
    with pytest.raises(AssertionError):
        assert(ps.initialize_model_parallel())
    initialize_distributed()
    # A tensor- or pipeline-parallel size of 2 cannot be formed from a
    # world size of 1, so initialization should raise.
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2))
    ps.initialize_model_parallel()

def test_other_initializations():
    assert(ps.model_parallel_is_initialized())
    assert(ps.get_model_parallel_group() is not None)
    assert(ps.get_tensor_model_parallel_group() is not None)
    assert(ps.get_pipeline_model_parallel_group() is not None)
    assert(ps.get_data_parallel_group() is not None)  
    assert(ps.get_embedding_group() is not None)  
    assert(ps.get_position_embedding_group() is not None)
    # TODO: Should change some of these tests below to actually test the code
    assert(ps.get_pipeline_model_parallel_first_rank() == 0)
    assert(ps.get_data_parallel_src_rank() == 0)
    assert(ps.get_pipeline_model_parallel_next_rank() == 0)
    assert(ps.get_pipeline_model_parallel_prev_rank() == 0)
    assert(ps.get_data_parallel_world_size() == world_size)
    assert(ps.get_data_parallel_rank() == 0)

def test_tensor_model_parallel_world_size():
    ps.set_tensor_model_parallel_world_size(world_size)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    # Clearing the override makes the getter fall back to the size of the
    # actual tensor-model-parallel group, which is also world_size here.
    ps.set_tensor_model_parallel_world_size(None)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)

def test_pipeline_model_parallel_world_size():
    ps.set_pipeline_model_parallel_world_size(world_size)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    ps.set_pipeline_model_parallel_world_size(None)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)

def test_tensor_model_parallel_rank():
    ps.set_tensor_model_parallel_rank(rank)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    ps.set_tensor_model_parallel_rank(None)
    assert(ps.get_tensor_model_parallel_rank() == rank)    

def test_pipeline_model_parallel_rank():
    ps.set_pipeline_model_parallel_rank(rank)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    ps.set_pipeline_model_parallel_rank(None)
    assert(ps.get_pipeline_model_parallel_rank() == rank)

def test_is_pipeline_first_stage():
    assert(ps.is_pipeline_first_stage(ignore_virtual=True))
    assert(ps.is_pipeline_first_stage())

def test_is_pipeline_last_stage():
    assert(
        ps.is_pipeline_last_stage(ignore_virtual=True) == (ps.get_pipeline_model_parallel_rank() == world_size-1)
        )
    assert(
        ps.is_pipeline_last_stage() == (ps.get_pipeline_model_parallel_rank() == world_size-1)
        )

def test_is_rank_in_embedding_group():
    assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
    if rank in ps._EMBEDDING_GLOBAL_RANKS:
        if rank == ps._EMBEDDING_GLOBAL_RANKS[0]:
            assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
        elif rank == ps._EMBEDDING_GLOBAL_RANKS[-1]:
            assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
        else:
            assert(ps.is_rank_in_embedding_group())

def test_is_rank_in_position_embedding_group():
    assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))

def test_is_pipeline_stage_before_split():
    # With pipeline_model_parallel_split_rank left as None, every stage counts
    # as being both before and after the split, so these asserts hold trivially.
    if world_size == 1:
        assert(ps.is_pipeline_stage_before_split())
    # TODO: Changes here for more than one world size
    assert(ps.is_pipeline_stage_before_split())

def test_is_pipeline_stage_after_split():
    if world_size == 1:
        assert(ps.is_pipeline_stage_after_split())
    # TODO: Changes here for more than one world size
    assert(ps.is_pipeline_stage_after_split())

def test_is_pipeline_stage_at_split():
    assert(
        ps.is_pipeline_stage_at_split() == 
        (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
        )

def test_virtual_pipeline_model_parallel_rank():
    ps.set_virtual_pipeline_model_parallel_rank(rank)
    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)

def test_get_virtual_pipeline_model_parallel_world_size():
    assert(ps.get_virtual_pipeline_model_parallel_world_size() == virtual_pipeline_model_parallel_size)

def test_get_tensor_model_parallel_src_rank():
    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))

def test_global_memory_buffer():
    ps._set_global_memory_buffer()
    assert(ps.get_global_memory_buffer() is not None)
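
# A minimal teardown sketch (not part of the original file): parallel_state
# exposes destroy_model_parallel() to clear the cached model-parallel groups,
# and torch.distributed.destroy_process_group() tears down the NCCL process
# group, so repeated runs in the same process start from a clean state.
def test_teardown():
    ps.destroy_model_parallel()
    assert(not ps.model_parallel_is_initialized())
    torch.distributed.destroy_process_group()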