# example_fusedmoe_torch.py
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional


# Reference code in PyTorch
class ExpertTorch(nn.Module):
    def __init__(self, config: Dict, d_expert: Optional[int] = None):
        super().__init__()
        self.config = config
        self.act_fn = nn.SiLU()
        self.d_hidden: int = config["d_hidden"]
        self.d_expert: int = config["d_expert"] if d_expert is None else d_expert

        self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
        self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.W_gate(x))
        out = self.W_down(gate * self.W_up(x))
        return out
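

# Equivalence sketch (illustrative, not part of the reference): the expert is a
# SwiGLU-style MLP, out = (SiLU(x @ Wg^T) * (x @ Wu^T)) @ Wd^T. The helper name
# and toy sizes below are assumptions for the demo.
def _check_expert_equivalence() -> None:
    expert = ExpertTorch({"d_hidden": 8, "d_expert": 16})
    x = torch.randn(4, 8)
    # nn.Linear stores weight as (out_features, in_features), hence the transposes.
    manual = (
        expert.act_fn(x @ expert.W_gate.weight.t()) * (x @ expert.W_up.weight.t())
    ) @ expert.W_down.weight.t()
    assert torch.allclose(expert(x), manual, atol=1e-6)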


class MoEGateTorch(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.num_experts: int = config["n_routed_experts"]
        self.d_hidden: int = config["d_hidden"]

        self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.W_g(x)
        scores = logits.softmax(dim=-1)
        topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        return topk_indices, topk_scores
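

# Routing demo sketch (sizes are illustrative assumptions): the gate softmaxes
# the router logits over all experts, then keeps the top_k scores per token.
# Note the kept scores are not renormalized to sum to one.
def _demo_gate() -> None:
    gate = MoEGateTorch({"n_experts_per_token": 2, "n_routed_experts": 4, "d_hidden": 8})
    idx, score = gate(torch.randn(1, 3, 8))
    print(idx.shape, score.shape)  # both torch.Size([1, 3, 2])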


class MoETorch(nn.Module):
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
        self.gating_network = MoEGateTorch(config)
        shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
        self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shared_output = self.shared_expert(x)
        expert_indices, expert_scores = self.gating_network(x)
        batch_size, seq_len, hidden_dim = x.shape
        orig_shape = x.shape
        x_flat = x.view(-1, hidden_dim)
        flat_expert_indices = expert_indices.view(-1)
        flat_expert_weights = expert_scores.view(-1, 1)
        routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights)

        routed_output = routed_output_flat.view(*orig_shape)
        return routed_output + shared_output

    @torch.no_grad()
    def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
        expert_cache = torch.zeros_like(x)

        # Sort the flat (token, expert) assignments by expert id so each expert
        # processes all of its tokens in a single contiguous batch. Experts that
        # receive no tokens (or fall past the end of bincount) are simply skipped.
        idxs = flat_expert_indices.argsort()
        counts = flat_expert_indices.bincount().cpu().numpy()
        tokens_per_expert = counts.cumsum()
        num_per_tok = self.config["n_experts_per_token"]
        # Each token occupies num_per_tok consecutive slots in the flat list,
        # so integer division maps a flat position back to its token row in x.
        token_idxs = idxs // num_per_tok
        for expert_id, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
            if start_idx == end_idx:
                continue

            expert = self.experts[expert_id]
            exp_token_idxs = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idxs]
            expert_out = expert(expert_tokens)

            # Weight each expert output by its routing score, then accumulate into
            # the cache; a token routed to several experts sums their contributions.
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")

        return expert_cache
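

# A minimal cross-check sketch, not part of the reference interface: the grouped
# dispatch in moe_infer should agree with a naive per-assignment loop. The helper
# name below is an illustrative assumption.
@torch.no_grad()
def _naive_moe_infer(moe: MoETorch, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
    top_k = moe.config["n_experts_per_token"]
    out = torch.zeros_like(x)
    for flat_pos, expert_id in enumerate(flat_expert_indices.tolist()):
        token = flat_pos // top_k  # each token owns top_k consecutive flat slots
        out[token] += flat_expert_weights[flat_pos] * moe.experts[expert_id](x[token])
    return out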


def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
    """
    Reference implementation of DeepSeek-style Mixture of Experts using PyTorch.

    Args:
        data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, d_hidden]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters

    Returns:
        output: Processed tensor of shape [batch_size, seq_len, d_hidden]
    """
    input_tensor, weights, config = data
    num_experts = config["n_routed_experts"]
    moe = MoETorch(config)

    # Fill in the given weights of the model
    moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])

    for i in range(num_experts):
        gate_proj_weight = weights[f"experts.{i}.0.weight"]
        up_proj_weight = weights[f"experts.{i}.1.weight"]
        down_proj_weight = weights[f"experts.{i}.2.weight"]

        # Transpose weights to match expected shape for nn.Linear
        moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
        moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
        moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())

    moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
    moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
    moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())

    output = moe(input_tensor)

    return output


# Input generation for the reference code


def generate_input(
    dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
) -> Tuple[torch.Tensor, Dict, Dict]:
    # The argument names avoid underscores because they don't parse correctly
    # in the harness; rebind them to the conventional snake_case names here.
    d_hidden = dhidden
    d_expert = dexpert
    n_routed_experts = nroutedexperts
    n_shared_experts = nsharedexperts
    n_experts_per_token = nexpertspertoken
    batch_size = bs
    seq_len = seqlen

    config = {
        "d_hidden": d_hidden,
        "d_expert": d_expert,
        "n_routed_experts": n_routed_experts,
        "n_shared_experts": n_shared_experts,
        "n_experts_per_token": n_experts_per_token,
        "batch_size": batch_size,
        "seq_len": seq_len,
    }

    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)

    num_experts = n_routed_experts
    expert_dim = d_expert
    weights = {}

    input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()

    # Initialize router weights
    weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)

    for i in range(num_experts):
        # Expert weights are stored as (in_features, out_features); ref_kernel
        # transposes them to match nn.Linear's (out_features, in_features) layout.
        weights[f"experts.{i}.0.weight"] = torch.randn(
            (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
        ) / math.sqrt(expert_dim)

        weights[f"experts.{i}.1.weight"] = torch.randn(
            (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
        ) / math.sqrt(expert_dim)

        weights[f"experts.{i}.2.weight"] = torch.randn(
            (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
        ) / math.sqrt(d_hidden)

    weights["shared_experts.0.weight"] = torch.randn(
        (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
    ) / math.sqrt(expert_dim * n_shared_experts)
    weights["shared_experts.1.weight"] = torch.randn(
        (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
    ) / math.sqrt(expert_dim * n_shared_experts)
    weights["shared_experts.2.weight"] = torch.randn(
        (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
    ) / math.sqrt(d_hidden)

    return (input_tensor, weights, config)


def clone_data(data):
    """
    Recursively goes through data and clones all tensors.
    """
    if isinstance(data, tuple):
        return tuple(clone_data(x) for x in data)
    elif isinstance(data, list):
        return [clone_data(x) for x in data]
    elif isinstance(data, dict):
        return {k: clone_data(v) for k, v in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.clone()
    else:
        return data
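

if __name__ == "__main__":
    # Smoke-test sketch: assumes a CUDA device, since generate_input allocates
    # on "cuda". All sizes below are illustrative, not canonical.
    data = generate_input(
        dhidden=128, dexpert=64, nroutedexperts=8, nsharedexperts=2,
        nexpertspertoken=2, bs=2, seqlen=16, seed=0,
    )
    output = ref_kernel(clone_data(data))
    print(output.shape)  # expected: torch.Size([2, 16, 128])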