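"""test_addmm_tp.py

Tests 1D tensor-parallel torch.addmm on a GPT-style Conv1D layer: the weight
(and, for the column-wise case, the bias) is sharded via ColoTensor specs and
the sharded forward/backward results are compared against a dense
single-device reference.
"""
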
import colossalai
import torch
import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor
from colossalai.tensor import distspec
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from colossalai.core import global_context as gpc
from _utils import tensor_shard_equal, tensor_equal


class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.ones(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


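# Row-wise 1D sharding: split the weight along dim 0 (the input features).
# The bias stays replicated, so only the weight receives the spec.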
def init_1d_row(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_spec(spec)


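# Column-wise 1D sharding: split both weight and bias along the last dim
# (the output features), so each rank owns a slice of the outputs.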
def init_1d_col(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_spec(spec)
        bias.set_spec(spec)


def run_with_spec(spec_init_func):
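    # Dense single-device reference: Conv1D(nf=4, nx=16) maps 16 features to 4.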
    model = Conv1D(4, 16).cuda()
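    # Wrap detached copies of the dense parameters as ColoTensors so they can
    # carry a distributed spec without touching the reference model.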
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
    spec_init_func(weight, bias)
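    # Forward check: the tensor-parallel addmm should reproduce the dense output.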
    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
    assert tensor_equal(out, colo_out)
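    # Backward check: push the same gradient through both paths and compare the
    # parameter gradients shard-wise.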
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad)
    assert tensor_shard_equal(model.bias.grad, bias.grad)


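# Per-process entry point: launch a 1D tensor-parallel group of `world_size`
# ranks, then exercise both sharding layouts.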
def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)


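# Spawns one process per rank; runs with both a single rank and four ranks.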
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_addmm_1d(4)