# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

6
import os
7
import tempfile
8

9
10
import pytest
import torch
11
import torch.distributed as dist
12
13
14
15
16

from fairscale.nn import MOELayer, Top2Gate

# pytest marker that skips the decorated test when no CUDA device is available
# (none of the tests shown in this file currently apply it).
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

# NCCL when a GPU is present, otherwise the CPU-capable Gloo backend.
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore

# Device types used by the device-parametrized tests; "cuda" only with a GPU.
devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]

# tempfile.mkstemp() returns (fd, path) with the file already open. Only the
# path is needed for the file:// rendezvous URL, so close the descriptor right
# away instead of leaking it for the rest of the test session.
_url_fd, _url_path = tempfile.mkstemp()
os.close(_url_fd)
URL = "file://" + _url_path
del _url_fd, _url_path

# Under an Open MPI launch (presumably detected via OMPI_COMM_WORLD_SIZE, which
# mpirun exports -- confirm for other launchers) the MPI process group is
# created once at import time; setup_module/teardown_module skip their own
# init/destroy in that case.
if "OMPI_COMM_WORLD_SIZE" in os.environ:
    dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)


def setup_module(module):
    """pytest module-level setup: create a single-process group for these tests.

    Skipped when running under MPI, because the MPI group was already created
    at import time.
    """
    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
        dist.init_process_group(backend=BACKEND, rank=0, world_size=1, init_method=URL)


35
def teardown_module(module):
36
37
    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
        torch.distributed.destroy_process_group()
38
39
40
41


@pytest.mark.parametrize("device", devices)
def test_create(device):
    """An MOELayer built from a Top2Gate and one Linear expert moves to ``device``."""
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).to(device)
    # Give the construction test an actual check: every registered parameter
    # must live on the requested device type after .to(device).
    assert all(p.device.type == device for p in moe.parameters())


@pytest.mark.parametrize("device", devices)
def test_expert_params(device):
    """After wrapping in MOELayer, each expert parameter carries ``expert = True``."""
    dim = 8
    gate = Top2Gate(dim, 4)
    linear_expert = torch.nn.Linear(dim, dim)
    # The layer is built (and moved) only so it can tag the expert's parameters.
    moe = MOELayer(gate, linear_expert).to(device)  # noqa: F841
    for param in linear_expert.parameters():
        assert param.expert is True


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward(device):
    """With an identity expert on every rank, the forward pass returns its input.

    num_experts equals the world size and each rank builds one expert.
    """
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inputs = torch.randn(4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(inputs)
    assert output.shape == inputs.shape
    # Re-assembled output should match input due to identity expert.
    assert torch.allclose(inputs, output)


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_multi(device):
    """Forward pass with several identity experts per rank.

    Some output elements may be zero (presumably tokens the gate did not
    dispatch to any expert -- TODO confirm against Top2Gate). Every non-zero
    output element must equal the corresponding input element.
    """
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    inputs = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(inputs)
    assert output.shape == inputs.shape
    nonzero = output != 0
    # 90% of the input should have gone to an expert
    assert nonzero.sum().item() / output.numel() > 0.90
    # Except for zeros, re-assembled output should match input due to identity
    # expert. Mask on `!= 0`, not `> 0`: the inputs are randn, so roughly half
    # of the valid outputs are negative and must be compared too.
    assert torch.allclose(inputs, torch.where(nonzero, output, inputs))


class RoundRobinGate(torch.nn.Module):
    """Deterministic test gate: token ``i`` goes to expert ``i % num_experts``.

    ``forward`` returns a 3-tuple ``(0.0, combine, dispatch)``:

    - ``combine`` has shape ``(s, num_experts, capacity)``, holds exactly one
      1.0 per token and uses the input's dtype and device.
    - ``dispatch`` is ``combine.bool()``.

    Presumably this mirrors Top2Gate's ``(l_aux, combine_weights,
    dispatch_mask)`` return -- confirm against fairscale. ``model_dim`` is
    stored but not used by ``forward``.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.model_dim = model_dim
        self.num_experts = num_experts

    def forward(self, input):
        # Only the first dimension (token count) is used; assumes the caller
        # passes tokens along dim 0 -- TODO confirm MOELayer's reshape.
        s = input.shape[0]
        assert s % self.num_experts == 0
        # Only s // num_experts slots per expert get filled; the buffer is sized
        # twice that.
        capacity = 2 * s // self.num_experts
        output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
        for token in range(s):
            # Round-robin: expert = token % E, slot in that expert = token // E.
            slot, expert_idx = divmod(token, self.num_experts)
            output[token, expert_idx, slot] = 1.0
        return 0.0, output, output.bool()


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_routing(device):
    """Each token reaches the expert chosen by RoundRobinGate.

    Rank r hosts one expert that scales by r + 1, so the scale observed in the
    output identifies which expert handled each token.
    """
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inputs = torch.randn(4, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use scaling matrix (each rank has a different scale)
    scale = dist.get_rank() + 1
    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
    moe = MOELayer(gate, expert).to(device)
    output = moe(inputs)
    assert output.shape == inputs.shape
    # Verify that each token was sent to the correct expert by checking its scale.
    for i in range(inputs.shape[1]):
        expert_idx = i % num_experts
        assert torch.allclose(inputs[:, i] * (expert_idx + 1), output[:, i])


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_routing_multi(device):
    """RoundRobinGate routing with several scaling experts on every rank.

    Local expert ``k`` on rank ``r`` scales by ``r * per_rank + k + 1``. That
    scale is distinct across all experts on all ranks, so each output column
    reveals which expert processed it.
    """
    dim = 8
    per_rank = 4
    total_experts = dist.get_world_size(dist.group.WORLD) * per_rank
    x = torch.randn(4 * per_rank, 16, dim).to(device)
    gate = RoundRobinGate(dim, total_experts)

    def make_scaled_identity(local_idx):
        # Scaling matrix (each rank has a different scale)
        layer = torch.nn.Linear(dim, dim, bias=False)
        factor = dist.get_rank() * per_rank + local_idx + 1
        layer.weight = torch.nn.Parameter(torch.eye(dim) * factor)
        return layer

    local_experts = [make_scaled_identity(k) for k in range(per_rank)]
    moe = MOELayer(gate, torch.nn.ModuleList(local_experts)).to(device)
    y = moe(x)
    assert y.shape == x.shape
    # Every token column must carry the scale of the expert it was routed to.
    for col in range(x.shape[1]):
        target = col % total_experts
        assert torch.allclose(x[:, col] * (target + 1), y[:, col])


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_backward(device):
    """Backward through MOELayer populates the expert weight's gradient.

    The expert is an identity map and the loss compares the output with the
    input, so the expected gradient of the expert weight is (close to) zero.
    A missing gradient (``None``) would make ``allclose`` raise.
    """
    loss_fn = torch.nn.MSELoss()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    inputs = torch.randn(4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(inputs)
    assert output.shape == inputs.shape
    loss = loss_fn(output, inputs)
    loss.backward()
    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))