# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest
import torch
import torch.distributed as dist

from fairscale.nn import MOELayer, Top2Gate

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

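# Prefer the NCCL backend when CUDA is available; fall back to GLOO for CPU-only runs.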
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore

if torch.cuda.is_available():
    devices = ["cpu", "cuda"]
else:
    devices = ["cpu"]

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
if "OMPI_COMM_WORLD_SIZE" in os.environ:
    dist.init_process_group(backend=dist.Backend.MPI)


def setup_module(module):
    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
        dist.init_process_group(backend=BACKEND, rank=0, world_size=1)


def teardown_module(module):
    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
        dist.destroy_process_group()
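

# Tests marked @pytest.mark.mpi require a multi-rank launch, e.g.:
#   mpirun -n 2 python -m pytest test_moe_layer.py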


@pytest.mark.parametrize("device", devices)
def test_create(device):
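    # Smoke test: constructing the layer and moving it to the device should not raise.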
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).to(device)


@pytest.mark.parametrize("device", devices)
def test_expert_params(device):
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).to(device)
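    # MOELayer tags expert parameters with `.expert = True` so that distributed
    # wrappers can treat them specially (e.g., skip gradient synchronization).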
    for p in expert.parameters():
        assert p.expert is True


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward(device):
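    # One expert per rank; MOELayer dispatches tokens across ranks via all-to-all.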
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    input = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # Re-assembled output should match input due to identity expert.
    assert torch.allclose(input, output)


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_multi(device):
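    # Each rank hosts num_local_experts experts, so the global expert count
    # scales with the world size.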
    torch.set_printoptions(threshold=5000)
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for i in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts += [expert]
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # With top-2 routing and finite expert capacity some tokens may be dropped, but
    # at least 90% of the output entries should be nonzero (routed to an expert).
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Aside from dropped (zeroed) positions, the re-assembled output should match
    # the input, due to the identity experts.
    assert torch.allclose(input, torch.where(output > 0, output, input))


# Test gate that routes tokens to experts in round-robin order.
class RoundRobinGate(torch.nn.Module):
    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.model_dim = model_dim
        self.num_experts = num_experts

    def forward(self, input):
        g, s, _ = input.shape
        assert s % self.num_experts == 0
        capacity = 2 * s // self.num_experts
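        # One-hot combine weights: token i goes to expert (i % num_experts),
        # occupying capacity slot (i // num_experts).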
        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
        for i in range(s):
            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
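        # Gate contract: (auxiliary loss, combine weights, boolean dispatch mask).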
        return 0.0, output, output.bool()


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_routing(device):
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    input = torch.randn(1, 4, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use scaling matrix (each rank has a different scale)
    scale = dist.get_rank() + 1
    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # Verify that each token was sent to the correct expert by checking its scale.
    t = input.shape[2]
    for i in range(t):
        expert = i % num_experts
        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_backward(device):
    loss = torch.nn.MSELoss()
    model_dim = 8
    num_experts = dist.get_world_size(dist.group.WORLD)
    input = torch.randn(1, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert output.shape == input.shape
    output = loss(output, input)
    output.backward()
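    # The identity expert reconstructs the input exactly, so the MSE loss is zero
    # and the expert weight receives an all-zero gradient.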
    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))