test_moe_layer.py 5.28 KB
Newer Older
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

6
7
import os

8
9
import pytest
import torch
10
import torch.distributed as dist
11
12
13
14
15

from fairscale.nn import MOELayer, Top2Gate

_HAS_CUDA = torch.cuda.is_available()

# Marker that skips a test when no CUDA device is available.
skip_if_no_cuda = pytest.mark.skipif(not _HAS_CUDA, reason="cuda required")

# Process-group backend for non-MPI runs: NCCL when CUDA is available, Gloo otherwise.
BACKEND = dist.Backend.NCCL if _HAS_CUDA else dist.Backend.GLOO  # type: ignore

# Device names used to parametrize the device-dependent tests.
devices = ["cpu", "cuda"] if _HAS_CUDA else ["cpu"]

# Fixed rendezvous address and port for torch.distributed initialisation.
os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "29501"})

# When OMPI_COMM_WORLD_SIZE is present (presumably an Open MPI launch), join
# the MPI process group at import time; otherwise setup_module() creates a
# single-process group.
if "OMPI_COMM_WORLD_SIZE" in os.environ:
    dist.init_process_group(backend=dist.Backend.MPI)
27
28
29
30
31


def setup_module(module):
    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
        dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
32
33


34
def teardown_module(module):
    """Destroy the process group once the module's tests have finished.

    Does nothing when ``OMPI_COMM_WORLD_SIZE`` is set, mirroring
    ``setup_module``, which only creates a group outside of that case.

    Args:
        module: The test module object passed in by pytest (unused).
    """
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        return
    dist.destroy_process_group()
37
38
39
40


@pytest.mark.parametrize("device", devices)
def test_create(device):
    """Smoke test: an MOELayer can be constructed and moved to ``device``."""
    dim, expert_count = 8, 4
    router = Top2Gate(dim, expert_count)
    linear = torch.nn.Linear(dim, dim)
    # No assertions: the test passes as long as construction and .to() do not raise.
    MOELayer(router, linear).to(device)
46
47


48
49
50
51
52
53
54
55
56
57
58
@pytest.mark.parametrize("device", devices)
def test_expert_params(device):
    """Expert parameters must carry ``expert is True`` after wrapping in an MOELayer."""
    dim, expert_count = 8, 4
    router = Top2Gate(dim, expert_count)
    linear = torch.nn.Linear(dim, dim)
    # The layer object itself is not inspected; only the expert's parameters are.
    MOELayer(router, linear).to(device)
    for param in linear.parameters():
        assert param.expert is True


59
60
61
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward(device):
    """Forward pass through an identity expert should reproduce the input."""
    dim = 8
    # The number of experts equals the size of the default process group.
    expert_count = dist.get_world_size(dist.group.WORLD)
    tokens = torch.randn(4, 16, dim).to(device)

    router = Top2Gate(dim, expert_count)
    identity = torch.nn.Linear(dim, dim, bias=False)
    # Overwrite the random initial weights with the identity matrix.
    identity.weight = torch.nn.Parameter(torch.eye(dim))
    layer = MOELayer(router, identity).to(device)

    result = layer(tokens)
    assert result.shape == tokens.shape
    # With an identity expert the re-assembled output should equal the input.
    assert torch.allclose(tokens, result)
74
75


76
77
78
79
80
81
82
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_multi(device):
    """Forward pass with several identity experts per rank.

    Builds ``num_local_experts`` identity experts, routes random tokens
    through a Top2Gate, and checks that

    * the output has the same shape as the input,
    * more than 90% of the output elements are non-zero, and
    * every non-zero output element matches the corresponding input element.
    """
    # Raised print threshold, presumably so failing tensors print in full.
    torch.set_printoptions(threshold=5000)
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    tokens = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(tokens)
    assert output.shape == tokens.shape
    # 90% of the input should have gone to an expert
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Except for zeros, re-assembled output should match input due to identity expert.
    # The mask must be `output != 0`: with `output > 0` every negative output
    # element would be replaced by the input and compared against itself, so
    # mismatches in negative values could never be detected.
    assert torch.allclose(tokens, torch.where(output != 0, output, tokens))


100
101
102
103
104
105
106
107
# Test-only gate that routes tokens to experts in round-robin order.
class RoundRobinGate(torch.nn.Module):
    """Deterministic gate: token ``i`` is assigned to expert ``i % num_experts``.

    Args:
        model_dim: Stored on the instance; not used by ``forward``.
        num_experts: Number of experts the tokens are spread across.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.model_dim = model_dim

    def forward(self, input):
        """Build the round-robin routing tensors for ``input``.

        Args:
            input: Tensor whose first dimension is the number of tokens; that
                count must be a multiple of ``num_experts``.

        Returns:
            A tuple ``(0.0, weights, mask)``. ``weights`` has shape
            ``(tokens, num_experts, capacity)`` with ``capacity`` equal to
            ``2 * tokens // num_experts``, uses the dtype and device of
            ``input``, and holds a single 1.0 per token. ``mask`` is
            ``weights.bool()``. (Presumably these are the auxiliary loss,
            combine weights and dispatch mask expected by MOELayer.)

        Raises:
            AssertionError: If the token count is not divisible by
                ``num_experts``.
        """
        token_count = input.shape[0]
        assert token_count % self.num_experts == 0
        capacity = 2 * token_count // self.num_experts
        weights = torch.zeros(
            token_count, self.num_experts, capacity, dtype=input.dtype, device=input.device
        )
        for token in range(token_count):
            # slot: position inside the chosen expert's buffer (token // num_experts);
            # expert_idx: the expert picked for this token (token % num_experts).
            slot, expert_idx = divmod(token, self.num_experts)
            weights[token, expert_idx, slot] = 1.0
        return 0.0, weights, weights.bool()


@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_routing(device):
    """Tokens routed by RoundRobinGate must be scaled by the expected expert.

    Every rank's expert multiplies its input by ``rank + 1``, so the scale
    seen on an output token shows which expert processed it.
    """
    dim = 8
    expert_count = dist.get_world_size()
    tokens = torch.randn(4, 16, dim).to(device)

    router = RoundRobinGate(dim, expert_count)
    scaler = torch.nn.Linear(dim, dim, bias=False)
    # Use scaling matrix (each rank has a different scale)
    rank_scale = dist.get_rank() + 1
    scaler.weight = torch.nn.Parameter(torch.eye(dim) * rank_scale)
    layer = MOELayer(router, scaler).to(device)

    result = layer(tokens)
    assert result.shape == tokens.shape
    # Verify that each token was sent to the correct expert by checking its scale.
    for position in range(tokens.shape[1]):
        expected_scale = position % expert_count + 1
        assert torch.allclose(tokens[:, position] * expected_scale, result[:, position])
136
137


138
139
140
141
142
143
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_backward(device):
    """Backward pass through an identity expert should give a zero weight gradient.

    The MSE loss is taken between the layer output and its own input, and
    the test asserts that the expert weight's gradient is all zeros.
    """
    criterion = torch.nn.MSELoss()
    dim = 8
    expert_count = dist.get_world_size(dist.group.WORLD)
    tokens = torch.randn(4, 16, dim).to(device)

    router = Top2Gate(dim, expert_count)
    identity = torch.nn.Linear(dim, dim, bias=False)
    # Use identity matrix
    identity.weight = torch.nn.Parameter(torch.eye(dim))
    layer = MOELayer(router, identity).to(device)

    result = layer(tokens)
    assert result.shape == tokens.shape
    mse = criterion(result, tokens)
    mse.backward()
    assert torch.allclose(identity.weight.grad, torch.zeros_like(identity.weight))