test_dropout.py 1.87 KB
Newer Older
FoolPlayer's avatar
FoolPlayer committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import itertools

import pytest
import torch
import torch.nn.functional as F

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.testing import rerun_if_address_is_in_use, spawn

# Launch config: data and pipeline parallel degree 1, 1D tensor parallelism of size 2.
CONFIG = {
    'parallel': {
        'data': 1,
        'pipeline': 1,
        'tensor': {'size': 2, 'mode': '1d'},
    },
}


def _assert_distinct_dropout_masks(outputs, scope):
    """Assert that no two tensors in ``outputs`` zero out exactly the same positions.

    Each unordered pair is compared once; comparing both orders is redundant
    because mask equality is symmetric.

    Args:
        outputs: Sequence of same-shaped tensors produced by a dropout layer.
        scope: Text spliced into the failure message, e.g. ``"each device"``.
    """
    for first, second in itertools.combinations(outputs, 2):
        # ``x == 0`` marks dropped positions (assuming no kept element is exactly 0);
        # ``torch.equal`` returns a plain Python bool, so no tensor-to-bool coercion.
        assert not torch.equal(first == 0, second == 0), \
            f"The dropout pattern in {scope} is not unique\n{first}\n{second}"


def check_dropout(rank, world_size, port):
    """Check that ``Dropout1D`` draws a different mask on every rank and on every call.

    Runs inside each spawned process. It initialises the distributed context from
    ``CONFIG``, then applies the dropout layer twice to a random local input.
    After each call it gathers the outputs of all ranks and asserts their dropout
    masks differ pairwise. Finally it asserts the two masks produced on this rank
    differ from each other.

    Args:
        rank: Rank of this process.
        world_size: Total number of spawned processes.
        port: Port used for the ``localhost`` rendezvous.
    """
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')

    # prepare data
    inputs = torch.randn(5, 4).to('cuda')
    dropout = Dropout1D(p=0.4).to('cuda')
    output_list = []
    # compare the dropout pattern in each device
    for _ in range(2):
        output = dropout(inputs)
        output_list.append(output)
        # zeros_like matches the output's dtype and device, as all_gather requires.
        dist_output_list = [torch.zeros_like(output) for _ in range(world_size)]
        torch.distributed.all_gather(dist_output_list, output)
        _assert_distinct_dropout_masks(dist_output_list, "each device")
    # compare the dropout pattern in local device
    _assert_distinct_dropout_masks(output_list, "one device")


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dropout():
    """Run ``check_dropout`` in two spawned processes, matching the tensor size in ``CONFIG``."""
    n_procs = 2
    spawn(check_dropout, n_procs)


if __name__ == '__main__':
    # Allow running the distributed check directly, outside pytest.
    test_dropout()