# test_device_mesh.py
import torch

from colossalai.device.device_mesh import DeviceMesh


def test_device_mesh():
    """Check DeviceMesh rank mapping on a 4x4 logical mesh of 16 devices.

    Builds a mesh from physical device ids 0..15 with shape (4, 4) and checks,
    via the asserted values below, that global ranks map to [row, column]
    positions in the row-major layout sketched in the comment, and that the
    axis-1 process group of rank 2 is the first row of that layout.
    """
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)
    # Logical layout the assertions below expect:
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7],
    #  [8, 9, 10,11],
    #  [12,13,14,15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # Global rank -> [row, column] position in the layout above.
    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
    # Rank 2 lies in row 0, so its group along axis 1 is that whole row.
    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


if __name__ == '__main__':
    # Allow running this test file directly as a script, without pytest.
    test_device_mesh()