# test_data.py: unit test for megatron.core.tensor_parallel.data.broadcast_data.
# Authors (from repository blame): Shanmugam Ramasamy, liangjing.
import torch
from megatron.core.tensor_parallel.data import broadcast_data
from tests.unit_tests.test_utilities import Utils


def test_broadcast_data():
    """Broadcast a dict of CUDA tensors and check the requested keys come back intact.

    Builds eight 8x8 tensors on the current CUDA device, where the tensor
    stored under key ``i`` is filled with the value ``i``. It then calls
    ``broadcast_data`` for keys 0 and 1 with dtype ``torch.float32`` and
    asserts that each returned tensor equals the corresponding input.

    The inputs are deterministic, so every process running this test builds
    the same tensors. Model-parallel state is always torn down, even if an
    assertion fails.
    """
    # NOTE(review): (2, 4) presumably means tensor-parallel size 2 and
    # pipeline-parallel size 4. Confirm against Utils.initialize_model_parallel.
    Utils.initialize_model_parallel(2, 4)
    try:
        # Key i maps to an 8x8 tensor filled with float(i). The distinct fill
        # values make a mix-up between keys detectable by torch.equal.
        input_data = {
            key: torch.ones((8, 8)).cuda() * float(key) for key in range(8)
        }
        dtype = torch.float32
        requested_keys = [0, 1]
        actual_output = broadcast_data(requested_keys, input_data, dtype)
        for key in requested_keys:
            assert torch.equal(actual_output[key], input_data[key])
    finally:
        # Tear down unconditionally so a failure here does not leave
        # initialized model-parallel state behind for later tests.
        Utils.destroy_model_parallel()