test_parallel_state.py 4.49 KB
Newer Older
shanmugamr's avatar
shanmugamr committed
1
2
3
import torch
import megatron.core.parallel_state as ps
import pytest
4
from tests.unit_tests.test_utilities import Utils
Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
5
import os 
shanmugamr's avatar
shanmugamr committed
6

Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
7
8
# Module-level shortcuts to values exposed by the shared test utilities.
# NOTE(review): presumably this process's global rank and the total number of
# launched processes — confirm against tests.unit_tests.test_utilities.Utils.
rank, world_size = Utils.rank, Utils.world_size
shanmugamr's avatar
shanmugamr committed
9

Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
10
def test_initialize__and_destroy_model_parallel():
    """Exercise model-parallel initialization, its validation errors, and teardown.

    The test proceeds in order:

    1. ``ps.initialize_model_parallel()`` called before
       ``Utils.initialize_distributed()`` must raise ``AssertionError``.
    2. After distributed initialization, four argument combinations must each
       raise ``RuntimeError``: tensor size ``2 * world_size``, pipeline size
       ``2 * world_size``, tensor and pipeline size both ``world_size``, and a
       virtual pipeline size of 2 with every other argument left at its default.
    3. A tensor=2 / pipeline=4 layout must initialize and expose non-``None``
       model, tensor, pipeline and data parallel groups.
    4. After ``Utils.destroy_model_parallel()`` the module-level
       ``ps._MODEL_PARALLEL_GROUP`` must be ``None`` again.

    NOTE(review): the tensor=2 / pipeline=4 layout presumably requires the
    test to be launched with a world size that is a multiple of 8 — confirm
    against the test launcher.
    """
    # Call the function directly inside ``pytest.raises``.  Wrapping the call
    # in ``assert`` made this check vacuous: a call that returned normally with
    # a falsy value (e.g. ``None``) would trip the ``assert`` and thereby raise
    # the very ``AssertionError`` that is being expected.
    with pytest.raises(AssertionError):
        ps.initialize_model_parallel()

    Utils.initialize_distributed()

    # Invalid size combinations.  No ``assert`` wrapper here either, so a call
    # that unexpectedly succeeds is reported by pytest as "DID NOT RAISE"
    # rather than as an unrelated AssertionError.
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size)
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(pipeline_model_parallel_size=2 * world_size)
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(
            pipeline_model_parallel_size=world_size,
            tensor_model_parallel_size=world_size,
        )
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2)

    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
    )

    assert ps.model_parallel_is_initialized()
    assert ps.get_model_parallel_group() is not None
    assert ps.get_tensor_model_parallel_group() is not None
    assert ps.get_pipeline_model_parallel_group() is not None
    assert ps.get_data_parallel_group() is not None

    Utils.destroy_model_parallel()
    # Reads the private module global directly to confirm teardown cleared it.
    assert ps._MODEL_PARALLEL_GROUP is None
shanmugamr's avatar
shanmugamr committed
31
32

def test_pipeline_parallel_initializations():
    """Check pipeline rank queries for a tensor=2 / pipeline=4 layout.

    Verifies the first pipeline rank, the data-parallel source rank, and the
    next/previous pipeline ranks (with wrap-around modulo ``world_size``).

    NOTE(review): the expected values use a fixed stride of 2, which presumably
    mirrors the tensor-parallel size, and the data-parallel source rank is
    expected to equal this process's own rank — this looks like it assumes a
    world size of 8; confirm against the test launcher.
    """
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
    )

    stride = 2
    expected_first = rank % stride
    expected_next = (rank + stride) % world_size
    expected_prev = (rank - stride) % world_size

    assert ps.get_pipeline_model_parallel_first_rank() == expected_first
    assert ps.get_data_parallel_src_rank() == rank
    assert ps.get_pipeline_model_parallel_next_rank() == expected_next
    assert ps.get_pipeline_model_parallel_prev_rank() == expected_prev

    Utils.destroy_model_parallel()

shanmugamr's avatar
shanmugamr committed
40
def test_data_parallel_initializations():
    """Check data-parallel queries when every process is a pipeline stage.

    With ``pipeline_model_parallel_size=world_size`` the test expects a
    data-parallel world size of 1, a data-parallel rank of 0, and a
    data-parallel source rank equal to this process's own rank.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    src_rank = ps.get_data_parallel_src_rank()
    assert src_rank == rank

    dp_world_size = ps.get_data_parallel_world_size()
    assert dp_world_size == 1

    dp_rank = ps.get_data_parallel_rank()
    assert dp_rank == 0

    Utils.destroy_model_parallel()
shanmugamr's avatar
shanmugamr committed
46
    
Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
47

shanmugamr's avatar
shanmugamr committed
48
def test_tensor_model_parellel_world_size():
    """Check the tensor-parallel world size before and after resetting it.

    The size is queried once after initialization with
    ``tensor_model_parallel_size=world_size`` and once more after calling the
    setter with ``None``; both queries must return ``world_size``.

    NOTE(review): passing ``None`` to the setter presumably clears a manual
    override so the getter falls back to the process group — confirm against
    ``megatron.core.parallel_state``.
    """
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)

    size_after_init = ps.get_tensor_model_parallel_world_size()
    assert size_after_init == world_size

    ps.set_tensor_model_parallel_world_size(None)
    size_after_reset = ps.get_tensor_model_parallel_world_size()
    assert size_after_reset == world_size

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
55
56

def test_pipeline_model_parallel_world_size():
    """Check the pipeline-parallel world size before and after resetting it.

    The size is queried once after initialization with
    ``pipeline_model_parallel_size=world_size`` and once more after calling
    the setter with ``None``; both queries must return ``world_size``.

    NOTE(review): passing ``None`` to the setter presumably clears a manual
    override so the getter falls back to the process group — confirm against
    ``megatron.core.parallel_state``.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    size_after_init = ps.get_pipeline_model_parallel_world_size()
    assert size_after_init == world_size

    ps.set_pipeline_model_parallel_world_size(None)
    size_after_reset = ps.get_pipeline_model_parallel_world_size()
    assert size_after_reset == world_size

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
63
64

def test_tensor_model_parallel_rank():
    """Check the tensor-parallel rank before and after resetting it.

    With ``tensor_model_parallel_size=world_size`` the tensor-parallel rank is
    expected to equal this process's rank, both right after initialization and
    after calling the setter with ``None``.

    NOTE(review): passing ``None`` to the setter presumably clears a manual
    override so the getter falls back to the process group — confirm against
    ``megatron.core.parallel_state``.
    """
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)

    rank_after_init = ps.get_tensor_model_parallel_rank()
    assert rank_after_init == rank

    ps.set_tensor_model_parallel_rank(None)
    rank_after_reset = ps.get_tensor_model_parallel_rank()
    assert rank_after_reset == rank

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
71

shanmugamr's avatar
shanmugamr committed
72
def test_pipeline_model_parallel_rank():
    """Check the pipeline-parallel rank before and after resetting it.

    With ``pipeline_model_parallel_size=world_size`` the pipeline-parallel rank
    is expected to equal this process's rank, both right after initialization
    and after calling the setter with ``None``.

    NOTE(review): passing ``None`` to the setter presumably clears a manual
    override so the getter falls back to the process group — confirm against
    ``megatron.core.parallel_state``.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    rank_after_init = ps.get_pipeline_model_parallel_rank()
    assert rank_after_init == rank

    ps.set_pipeline_model_parallel_rank(None)
    rank_after_reset = ps.get_pipeline_model_parallel_rank()
    assert rank_after_reset == rank

    Utils.destroy_model_parallel()
shanmugamr's avatar
shanmugamr committed
78
    
Shanmugam Ramasamy's avatar
Shanmugam Ramasamy committed
79

shanmugamr's avatar
shanmugamr committed
80
def test_is_pipeline_first_stage():
    """Check that only rank 0 reports itself as the first pipeline stage.

    The layout uses ``pipeline_model_parallel_size=world_size``; the query is
    made both with ``ignore_virtual=True`` and with its default arguments.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    should_be_first = rank == 0
    assert ps.is_pipeline_first_stage(ignore_virtual=True) == should_be_first
    assert ps.is_pipeline_first_stage() == should_be_first

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
86
87

def test_is_pipeline_last_stage():
    """Check that only the highest rank reports itself as the last pipeline stage.

    The layout uses ``pipeline_model_parallel_size=world_size``; the query is
    made both with ``ignore_virtual=True`` and with its default arguments.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    should_be_last = rank == world_size - 1
    assert ps.is_pipeline_last_stage(ignore_virtual=True) == should_be_last
    assert ps.is_pipeline_last_stage() == should_be_last

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
93
94

def test_virtual_pipeline_model_parallel_rank():
    """Check that the virtual pipeline rank setter and getter round-trip.

    After initializing with ``pipeline_model_parallel_size=world_size``, the
    virtual pipeline rank is set to this process's rank and read back.
    """
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)

    ps.set_virtual_pipeline_model_parallel_rank(rank)
    stored_rank = ps.get_virtual_pipeline_model_parallel_rank()
    assert stored_rank == rank

    Utils.destroy_model_parallel()
    
shanmugamr's avatar
shanmugamr committed
100
101

def test_get_tensor_model_parallel_src_rank():
    """Check the tensor-parallel source rank for a single world-sized group.

    With ``tensor_model_parallel_size=world_size`` the expected source rank is
    this process's rank rounded down to a multiple of ``world_size``.
    """
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)

    expected_src_rank = (rank // world_size) * world_size
    assert ps.get_tensor_model_parallel_src_rank() == expected_src_rank

    Utils.destroy_model_parallel()