test_moe_layer.py 7.11 KB
Newer Older
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

6
import functools
import os
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fair_dev.testing.testing import make_cudnn_deterministic
from fairscale.internal import torch_version
from fairscale.nn import MOELayer, Top2Gate

18
19
20
# Skip every test in this module unless CUDA is available and torch >= 1.8.0.
pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
)
21

22
devices = ["cuda"]
23

24

25
26
27
28
29
30
31
def pg_worker(rank, world_size, init_file, func, *args):
    """Per-process entry point: join the NCCL group rendezvoused via
    ``init_file``, run a warm-up collective, execute ``func``, tear down."""
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        rank=rank,
        world_size=world_size,
        init_method="file://" + init_file,
    )
    torch.cuda.set_device(rank)
    # Warm-up all_reduce confirms the process group is functional before the test body runs.
    dist.all_reduce(torch.zeros(1).cuda())
    func(*args)
    dist.destroy_process_group()
32

33

34
35
36
37
38
39
def pg_test(world_size=torch.cuda.device_count()):
    """Decorator factory: run the wrapped function in a fresh ``world_size``-process
    NCCL group, and register the spawning wrapper as ``test_<name>`` in module
    globals so pytest collects it (the undecorated function is returned unchanged
    so pytest does not also collect the raw, non-distributed version).

    Note: pytest passes parametrized values as keyword arguments, which is why
    only ``kwargs.values()`` is forwarded to the workers.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # mkstemp returns an OPEN file descriptor plus the path; close the
            # fd immediately so it is not leaked — only the path is needed for
            # the file:// rendezvous in pg_worker.
            fd, tempfile_name = tempfile.mkstemp()
            os.close(fd)
            mp.spawn(pg_worker, args=(world_size, tempfile_name, func, *kwargs.values()), nprocs=world_size)

        globals()["test_" + func.__name__] = wrapper
        return func

    return decorator
45
46


47
@pg_test(world_size=1)
@pytest.mark.parametrize("device", devices)
def create(device):
    """Smoke test: an MOELayer wrapping a single Linear expert can be built
    and moved to the target device."""
    d_model = 8
    n_experts = 4
    routing_gate = Top2Gate(d_model, n_experts)
    single_expert = torch.nn.Linear(d_model, d_model)
    MOELayer(routing_gate, single_expert).to(device)
55
56


57
@pg_test(world_size=1)
@pytest.mark.parametrize("device", devices)
def expert_params(device):
    """MOELayer must mark every expert parameter with ``.expert is True``."""
    d_model = 8
    n_experts = 4
    routing_gate = Top2Gate(d_model, n_experts)
    expert_module = torch.nn.Linear(d_model, d_model)
    MOELayer(routing_gate, expert_module).to(device)
    for param in expert_module.parameters():
        # The layer tags expert params so the trainer can exclude them from data-parallel sync.
        assert param.expert is True, str(param.expert)
67
68


69
70
71
@pg_test()
@pytest.mark.parametrize("device", devices)
def forward(device):
    """With identity-weight experts, the re-assembled output must equal the input."""
    make_cudnn_deterministic()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inp = torch.randn(4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    out = MOELayer(gate, expert).to(device)(inp)
    assert out.shape == inp.shape, f"{out.shape} != {inp.shape}"
    # Re-assembled output should match input due to identity expert.
    torch.testing.assert_allclose(inp, out)
85
86


87
88
89
@pg_test()
@pytest.mark.parametrize("device", devices)
def forward_multi(device):
    """Several identity experts per rank: routed tokens must come back unchanged,
    and the vast majority of tokens must have been routed at all."""
    make_cudnn_deterministic()
    torch.set_printoptions(threshold=5000)
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    inp = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    local_experts = torch.nn.ModuleList()
    for _ in range(num_local_experts):
        lin = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        lin.weight = torch.nn.Parameter(torch.eye(model_dim))
        local_experts.append(lin)
    out = MOELayer(gate, local_experts).to(device)(inp)
    assert out.shape == inp.shape, f"{out.shape} != {inp.shape}"
    # 90% of the input should have gone to an expert
    routed = len(out.nonzero(as_tuple=False))
    assert routed / out.numel() > 0.90, f"{routed} / {out.numel()}"
    # Except for zeros, re-assembled output should match input due to identity expert.
    torch.testing.assert_allclose(inp, torch.where(out > 0, out, inp))
112
113


114
115
116
117
118
119
120
121
# Test Gate which round-robin routes tokens to experts
class RoundRobinGate(torch.nn.Module):
    """Deterministic gate for testing: token ``i`` is dispatched, with weight
    1.0, to expert ``i % num_experts`` at capacity slot ``i // num_experts``."""

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.model_dim = model_dim
        self.num_experts = num_experts

    def forward(self, input):
        num_tokens = input.shape[0]
        assert num_tokens % self.num_experts == 0, f"{num_tokens} % {self.num_experts} != 0"
        # Capacity factor of 2, mirroring Top2Gate's default sizing.
        capacity = 2 * num_tokens // self.num_experts
        combine = torch.zeros(num_tokens, self.num_experts, capacity, dtype=input.dtype, device=input.device)
        for token in range(num_tokens):
            combine[token, token % self.num_experts, token // self.num_experts] = 1.0
        # Returns (aux loss, combine weights, dispatch mask) — same contract as Top2Gate.
        return 0.0, combine, combine.bool()


131
132
133
@pg_test()
@pytest.mark.parametrize("device", devices)
def forward_routing(device):
    """Round-robin routing with one expert per rank: each token must come back
    multiplied by the scale unique to the expert it was routed to."""
    make_cudnn_deterministic()
    model_dim = 8
    num_experts = dist.get_world_size()
    inp = torch.randn(4, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use scaling matrix (each rank has a different scale)
    scale = dist.get_rank() + 1
    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
    out = MOELayer(gate, expert).to(device)(inp)
    assert out.shape == inp.shape, f"{out.shape} != {inp.shape}"
    # Verify that each token was sent to the correct expert by checking its scale.
    seq_len = inp.shape[1]
    for token_idx in range(seq_len):
        expected_scale = token_idx % num_experts + 1
        torch.testing.assert_allclose(inp[:, token_idx] * expected_scale, out[:, token_idx])
151
152


153
154
155
@pg_test()
@pytest.mark.parametrize("device", devices)
def forward_routing_multi(device):
    """Round-robin routing with multiple local experts per rank: each token must
    come back multiplied by its target expert's globally-unique scale."""
    make_cudnn_deterministic()
    model_dim = 8
    num_local_experts = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    inp = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    local_experts = []
    for local_idx in range(num_local_experts):
        lin = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use scaling matrix (each rank has a different scale)
        scale = dist.get_rank() * num_local_experts + local_idx + 1
        lin.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
        local_experts.append(lin)
    out = MOELayer(gate, torch.nn.ModuleList(local_experts)).to(device)(inp)
    assert out.shape == inp.shape, f"{out.shape} != {inp.shape}"
    # Verify that each token was sent to the correct expert by checking its scale.
    seq_len = inp.shape[1]
    for token_idx in range(seq_len):
        expected_scale = token_idx % num_experts + 1
        torch.testing.assert_allclose(inp[:, token_idx] * expected_scale, out[:, token_idx])
177
178


179
180
181
@pg_test()
@pytest.mark.parametrize("device", devices)
def backward(device):
    """With identity experts, MSE(output, input) is zero, so backward must
    leave an all-zero gradient on the expert weights."""
    make_cudnn_deterministic()
    criterion = torch.nn.MSELoss()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inp = torch.randn(4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    out = MOELayer(gate, expert).to(device)(inp)
    assert out.shape == inp.shape, f"{out.shape} != {inp.shape}"
    criterion(out, inp).backward()
    torch.testing.assert_allclose(expert.weight.grad, torch.zeros_like(expert.weight))