"official/projects/nhnet/decoder.py" did not exist on "4b0cec67221923d05d854631a221bd3dc4606664"
test_top2gating.py 2.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from fairscale.nn import Top2Gate
from fairscale.nn.moe.top2gate import top2gating

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
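
# Note on the constructor used throughout: Top2Gate(model_dim, num_experts).
# Judging from the shape assertions in do_test_forward below, the first argument
# is the token feature size and the second is the number of experts.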


def test_create():
    gate = Top2Gate(4, 8)


@skip_if_no_cuda
def test_create_cuda():
    gate = Top2Gate(4, 8).cuda()


def do_test_forward(device):
    torch.manual_seed(3)
    input = torch.randn(12, 4).to(device)
    gate = Top2Gate(4, 6).to(device)
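    # Capacity per expert follows the top-2 rule: each of the 12 tokens may be
    # dispatched to at most 2 experts, and those 2 * 12 slots are split evenly
    # across the 6 experts (2 * 12 // 6 == 4 slots each).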
    capacity = 2 * 12 // 6
    l_aux, combine_weights, dispatch_mask = gate(input)
    assert pytest.approx(l_aux.item(), rel=0.01) == 0.0267, l_aux
    assert combine_weights.shape == (12, 6, 4)
    assert dispatch_mask.shape == (12, 6, 4)
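    # dispatch_mask is expected to be the boolean support of combine_weights:
    # True exactly where a token received a nonzero weight for an expert slot.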
    assert torch.equal(combine_weights.bool(), dispatch_mask)
    assert torch.all(torch.sum(dispatch_mask, axis=(0, 2)) <= capacity)
    assert torch.all(combine_weights >= 0.0)
    assert torch.all(combine_weights <= 1.0)
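    # The total combine weight should land on a whole number (checked
    # approximately below), presumably because each routed token's weights are
    # normalized to sum to 1 across its expert slots.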
    weights_sum = torch.sum(combine_weights).item()
    assert round(weights_sum) == pytest.approx(weights_sum), weights_sum
    # For this random seed, we get 12 slots filled.
    assert weights_sum == pytest.approx(12.0), weights_sum


def test_forward_cpu():
    do_test_forward("cpu")


@skip_if_no_cuda
def test_forward_cuda():
    do_test_forward("cuda")


# Verify that the top-1 expert's capacity is allocated as per Algorithm 1 in the GShard paper.
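# When every token's top-1 choice is the same expert, only the first `capacity`
# tokens should occupy that expert's slots; the remaining tokens overflow and
# receive no top-1 dispatch.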
def test_expert1_overflow():
    num_tokens = 8
    num_experts = 4
    logits = torch.randn(num_tokens, num_experts)
    logits[:, 0] = torch.max(logits, dim=1).values + 1  # Force overflow
    top1s = torch.argmax(logits, dim=1)
    assert top1s.eq(0).all(), top1s
    _, __, dispatch_mask = top2gating(logits)
    capacity = 2 * num_tokens // num_experts

    for i in range(num_tokens):
        if i < capacity:
            assert dispatch_mask[i][0][i]
        else:
            assert not dispatch_mask[i][0].any()
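

# Illustrative sketch only (not exercised by the tests above): one plausible way
# the gate outputs could be consumed, assuming hypothetical per-expert modules in
# `experts` and tokens of shape (num_tokens, model_dim). The einsum layout mirrors
# the (tokens, experts, capacity) shapes asserted in do_test_forward; it is not a
# reference implementation of fairscale's MoE layer.
def _sketch_dispatch_and_combine(gate, experts, tokens):
    l_aux, combine_weights, dispatch_mask = gate(tokens)
    # Scatter tokens into per-expert capacity slots: (experts, capacity, model_dim).
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.float(), tokens)
    # Run each expert on the slots assigned to it.
    expert_out = torch.stack([expert(dispatched[e]) for e, expert in enumerate(experts)])
    # Combine expert outputs back into token order, weighted by the gate: (tokens, model_dim).
    return torch.einsum("sec,ecm->sm", combine_weights, expert_out), l_aux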