import torch
import torch.nn as nn
from fmoe import FMoETransformerMLP
from fmoe.gates import NaiveGate
from moe import BruteForceMoELinear
import time
import os


rank = None
world_size = None
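# Filled in by __main__: taken from torch.distributed when WORLD_SIZE > 1,
# otherwise rank 0 / world size 1.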
dev_name_default = "cuda:0"


class BruteForceMoE(nn.Module):
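    """Brute-force reference MoE layer built on BruteForceMoELinear.

    The constructor mirrors FMoETransformerMLP's arguments so the benchmark
    can instantiate either layer the same way; mp_group is accepted but
    unused, and only world_size == 1 is supported (see the assert below).
    """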
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=1,
        pre_lnorm=False,
    ):
        assert world_size == 1, "Distributed brute force is not supported"
        super().__init__()
        self.mlp = BruteForceMoELinear(
            activation, num_expert, d_model, d_hidden, 1, top_k
        )
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, inp):
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
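        # The gate returns, for every token, the indices of its top-k experts
        # and the corresponding gate scores.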
        gate_top_k_idx, gate_score = self.gate(inp)
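        # Duplicate each token top_k times so that copy i can be routed to
        # its i-th selected expert.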
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        x = self.mlp(inp, gate_top_k_idx, gate_score)
        if not self.pre_lnorm:
            x = self.layer_norm(x)
        return x


def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k):
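    """Time n_runs forward/backward passes of MOELayer on random input and
    print the mean/max/stdev forward time, mean backward time, and GFLOPs."""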
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    if rank == 0:
        print(
            "Performance test of {} mm size {} {}x{} experts {}x{} topk {}".format(
                MOELayer.__name__,
                batch_size,
                in_feat,
                hidden_feat,
                world_size,
                num_expert,
                top_k,
            )
        )
    if world_size > 1:
        dev_name = "cuda"
    else:
        dev_name = dev_name_default

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    inp.requires_grad = True

    moe = MOELayer(
        num_expert=num_expert,
        d_model=in_feat,
        d_hidden=hidden_feat,
        world_size=world_size,
        top_k=top_k,
    ).cuda(dev_name)
    moe.train()

    # warm up
    for _ in range(4):
        _ = moe(inp)

    n_runs = 16
    tott = 0.0
    backt = 0.0
    maxt = 0.0
    sqtot = 0.0
    # CUDA kernels launch asynchronously, so synchronize before reading each
    # timestamp; otherwise only the kernel-launch overhead would be measured.
    for _ in range(n_runs):
        torch.cuda.synchronize()
        ts = time.time()
        o = moe(inp)
        torch.cuda.synchronize()
        te = time.time()

        loss = o.sum()

        torch.cuda.synchronize()
        bts = time.time()
        loss.backward()
        torch.cuda.synchronize()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts) ** 2
        maxt = max(maxt, te - ts)
        backt += bte - bts

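    # FLOP model: 2 FLOPs per multiply-add; each routed token copy
    # (batch_size * top_k of them) passes through two in_feat x hidden_feat
    # GEMMs, plus a batch_size x in_feat x num_expert gate projection.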
    gflops = (
        2e-9
        * n_runs
        * (
            in_feat * hidden_feat * batch_size * top_k * 2
            + batch_size * in_feat * num_expert
        )
        / tott
    )
    print(
        "Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, {:.3f} GFLOPs".format(
            tott * 1e3 / n_runs,
            maxt * 1e3,
            (max(sqtot / n_runs - (tott / n_runs) ** 2, 0.0) ** 0.5) * 1e3,
            backt * 1e3 / n_runs,
            gflops,
        )
    )


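# Configuration comes from environment variables (BATCH_SIZE, D_MODEL,
# D_HIDDEN, NUM_EXPERT, TOP_K); WORLD_SIZE > 1 (e.g. as set by torchrun)
# selects the distributed NCCL path.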
if __name__ == "__main__":
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        # Pin each process to its own GPU before allocating tensors; this
        # assumes one process per local device (the common single-node
        # torchrun setup).
        torch.cuda.set_device(rank % torch.cuda.device_count())
    else:
        rank = 0
        world_size = 1
    batch_size = int(os.environ.get("BATCH_SIZE", "4096"))
    d_model = int(os.environ.get("D_MODEL", "1024"))
    d_hidden = int(os.environ.get("D_HIDDEN", "4096"))
    num_expert = int(os.environ.get("NUM_EXPERT", "64"))
    top_k = int(os.environ.get("TOP_K", "2"))
    benchmark_mlp(FMoETransformerMLP, batch_size, d_model, d_hidden, num_expert, top_k)
    if world_size == 1:
        benchmark_mlp(BruteForceMoE, batch_size, d_model, d_hidden, num_expert, top_k)