# test_addmm_tp.py
import pytest
2
import torch
3
import torch.nn as nn
4
5
6
7
8

import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34


class Conv1D(nn.Module):
    """
    GPT-style "Conv1D" projection (Radford et al., OpenAI GPT / GPT-2).

    Functionally a linear layer whose weight is stored transposed, i.e. with
    shape ``(nx, nf)`` rather than ``(nf, nx)``, so the forward pass is
    ``bias + x @ weight`` evaluated with a single ``torch.addmm``.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight ~ N(0, 0.02^2) with shape (nx, nf); bias starts as all ones.
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        # Fold every leading dimension into one so addmm receives a 2-D matrix.
        flat_x = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat_x, self.weight)
        # Restore the leading dimensions, swapping the last one for nf.
        return projected.view(x.size()[:-1] + (self.nf,))


35
def run_with_spec(spec_init_func, split_bias):
    """Compare tensor-parallel ``torch.addmm`` on ColoTensors with a dense ``Conv1D``.

    A dense reference ``Conv1D(4, 16)`` is built on the current CUDA device. Its
    weight and bias are copied into ColoTensors. The weight (and optionally the
    bias) is sharded with ``spec_init_func``. Then the forward outputs and the
    parameter gradients of both paths are compared.

    Args:
        spec_init_func: Callable ``(tensor, pg)`` that applies a tensor-parallel
            spec to ``tensor``. Its return value is ignored, so it is expected to
            act in place (e.g. ``split_param_row_tp1d`` / ``split_param_col_tp1d``).
        split_bias: If True, the bias is also passed through ``spec_init_func``;
            otherwise it is left with the plain ``ColoTensorSpec(pg)``.

    Raises:
        AssertionError: If the forward output or either gradient does not match.
    """
    model = Conv1D(4, 16).cuda()

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    # detach() + a fresh Parameter gives new autograd leaves, so gradients
    # accumulate separately into model.weight/bias and into weight/bias here.
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))

    spec_init_func(weight, pg)
    if split_bias:
        spec_init_func(bias, pg)

    # NOTE(review): assumes every rank draws the same x / grad (e.g. the RNG is
    # seeded identically by colossalai.launch) — confirm.
    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
    colo_out = colo_out.to_replicate()
    assert tensor_equal(out, colo_out), "addmm forward output mismatch"

    # Drive both graphs with the same upstream gradient.
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)

    # tensor_shard_equal's result must be asserted; a bare call checks nothing.
    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    assert tensor_shard_equal(model.weight.grad, weight.grad, tp_rank, tp_size), "weight grad mismatch"
    assert tensor_shard_equal(model.bias.grad, bias.grad, tp_rank, tp_size), "bias grad mismatch"
57
58
59


def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then check both sharding setups.

    Args:
        rank: Rank of this process.
        world_size: Total number of processes.
        port: Port used with host ``localhost`` for the NCCL launch.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # Weight sharded with split_param_row_tp1d; bias left unsharded.
    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
    # Weight and bias both sharded with split_param_col_tp1d.
    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)
63
64
65


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist``.

    ``rerun_if_address_is_in_use`` (by its name) re-runs the test when the
    launch address is already taken.
    """
    spawn(run_dist, world_size)
70
71
72


if __name__ == '__main__':
    # Running the file directly exercises the 4-process configuration.
    test_addmm_1d(4)