import os
import random
from functools import partial
from typing import Callable, Type

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def set_seed(seed):
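    # Make Python, NumPy and PyTorch (CPU and CUDA) RNGs deterministic for this rank.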
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
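    # Wrap the module in ZeroDDP managed by Gemini; optionally let ChunkManager
    # search a chunk size so parameters are grouped into chunks.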
    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, pg)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    return ZeroDDP(module, gemini_manager)


class Net(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 3, bias=False)
        self.fc2 = torch.nn.Linear(3, 1, bias=False)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
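    # Mark fc2's weight as ignored, run one forward/backward pass, and verify that
    # its gradient is NOT synchronized across ranks while fc1's gradient is.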
    with ColoInitContext(device=get_current_device()):
        model = Net().cuda()
    w1 = model.fc1.weight
    w2 = model.fc2.weight
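    # Exclude w2 from gradient synchronization before wrapping the model.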
    ddp_cls.set_params_to_ignore([w2])
    model = init_ddp_func(model)
    x = torch.rand(2, 3, device=get_current_device())
    logits = model(x)
    loss = torch.sum(logits)
    model.backward(loss)
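    # Gather w1's gradient from every rank; it must be identical since w1 is synchronized.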
    w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
    dist.all_gather(w1_grads, w1.grad)
    assert torch.equal(w1_grads[0], w1_grads[1])
    w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
    dist.all_gather(w2_grads, w2.grad)
    assert not torch.equal(w2_grads[0], w2_grads[1])


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
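    # Seed each rank differently so locally generated inputs, and hence
    # unsynchronized gradients, diverge between ranks.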
    set_seed(dist.get_rank())
    run_fwd_bwd(ColoDDP, init_ddp)
    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_ddp_ignore_params(world_size):
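    # Spawn one process per rank and run the check for ColoDDP and both ZeroDDP variants.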
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_ddp_ignore_params(2)