# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Fused functions used in the MoE router
"""
import torch
import transformer_engine_torch as tex


class FusedTopkScoreFunction(torch.autograd.Function):
    """
    Fused Topk with Score Function router.
    Currently, only support softmax and sigmoid.
    """

    @staticmethod
    def forward(
        ctx,
        logits: torch.Tensor,
        topk: int,
        use_pre_softmax: bool,
        num_groups: int,
        group_topk: int,
        scaling_factor: float,
        score_function: str,
        expert_bias: torch.Tensor,
    ):
        # pylint: disable=missing-function-docstring
        # Save the shape of the logits so probs can be restored to it below
        tensor_shape = logits.shape
        # The fused kernel operates on a 2D [tokens, experts] view
        logits = logits.view(-1, tensor_shape[-1])
        # Get the metadata of the viewed logits
        num_tokens = logits.size(0)
        num_experts = logits.size(1)

        probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
            logits,
            topk,
            use_pre_softmax,
            num_groups,
            group_topk,
            scaling_factor,
            score_function,
            expert_bias,
        )

        # Restore the shape
        probs = probs.view(tensor_shape)

        # Tensors needed by the backward fused kernel
        ctx.save_for_backward(routing_map, intermediate_output)
        # Non-tensor metadata for the backward fused kernel
        ctx.num_tokens = num_tokens
        ctx.num_experts = num_experts
        ctx.use_pre_softmax = use_pre_softmax
        ctx.topk = topk
        ctx.scaling_factor = scaling_factor
        ctx.score_function = score_function
        return probs, routing_map

    @staticmethod
    def backward(ctx, grad_probs, _):
        # pylint: disable=missing-function-docstring
        # routing_map is non-differentiable, so its incoming grad is ignored
        routing_map, intermediate_output = ctx.saved_tensors

        # Save the shape of the grad_probs
        tensor_shape = grad_probs.shape
        # Adjust the shape of the grad_probs to 2D shape
        grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])

        grad_logits = tex.fused_topk_with_score_function_bwd(
            ctx.num_tokens,
            ctx.num_experts,
            routing_map,
            intermediate_output,
            grad_probs,
            ctx.topk,
            ctx.use_pre_softmax,
            ctx.scaling_factor,
            ctx.score_function,
        )

        # Restore the shape
        grad_logits = grad_logits.view(tensor_shape)

        # One gradient per forward input; only logits is differentiable
        return grad_logits, None, None, None, None, None, None, None


def fused_topk_with_score_function(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool,
    num_groups: int,
    group_topk: int,
    scaling_factor: float,
    score_function: str,
    expert_bias: torch.Tensor,
):
    """
    Fused topk with score function router.

    Parameters
    ----------
    logits : torch.Tensor
    topk : int
    use_pre_softmax : bool
        if enabled, the computation order: softmax -> topk
    num_groups : int
        used in the group topk
    group_topk : int
        used in the group topk
    scaling_factor : float
    score_function : str
        currently only support softmax and sigmoid
    expert_bias : torch.Tensor
        could be used in the sigmoid

    Returns
    -------
    probs : torch.Tensor
    routing_map : torch.Tensor

    Raises
    ------
    ValueError
        if the logits dtype is float64, which the fused kernel does not support.
    """
    if logits.dtype == torch.float64:
        raise ValueError("Current TE does not support float64 router type")
    return FusedTopkScoreFunction.apply(
        logits,
        topk,
        use_pre_softmax,
        num_groups,
        group_topk,
        scaling_factor,
        score_function,
        expert_bias,
    )


class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
    """
    Fused compute scores for MoE aux loss.
    """

    @staticmethod
    def forward(
        ctx,
        logits: torch.Tensor,
        topk: int,
        score_function: str,
    ):
        # pylint: disable=missing-function-docstring
        # Save the shape of the logits
        tensor_shape = logits.shape
        # The fused kernel operates on a 2D [tokens, experts] view
        logits = logits.view(-1, tensor_shape[-1])
        # Get the metadata of the viewed logits
        num_tokens = logits.size(0)
        num_experts = logits.size(1)

        scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )
        # Tensor needed by the backward fused kernel
        ctx.save_for_backward(intermediate_output)
        # Non-tensor metadata for the backward fused kernel
        ctx.topk = topk
        ctx.score_function = score_function
        ctx.num_tokens = num_tokens
        ctx.num_experts = num_experts
        return routing_map, scores

    @staticmethod
    def backward(ctx, _, grad_scores):
        # pylint: disable=missing-function-docstring
        # routing_map is non-differentiable, so its incoming grad is ignored
        intermediate_output = ctx.saved_tensors[0]

        # Save the shape of the grad_scores
        tensor_shape = grad_scores.shape
        # Adjust the shape of the grad_scores to 2D shape
        grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])

        grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
            num_tokens=ctx.num_tokens,
            num_experts=ctx.num_experts,
            intermediate_output=intermediate_output,
            grad_scores=grad_scores,
            topk=ctx.topk,
            score_function=ctx.score_function,
        )

        # Restore the shape
        grad_logits = grad_logits.view(tensor_shape)

        # One gradient per forward input; only logits is differentiable
        return grad_logits, None, None


def fused_compute_score_for_moe_aux_loss(
    logits: torch.Tensor,
    topk: int,
    score_function: str,
):
    """
    Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.

    Parameters
    ----------
    logits : torch.Tensor
    topk : int
    score_function : str
        currently only support softmax and sigmoid

    Returns
    -------
    routing_map : torch.Tensor
    scores : torch.Tensor
    """
    return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)


class FusedAuxLoss(torch.autograd.Function):
    """
    Fused MoE aux loss.
    """

    @staticmethod
    def forward(
        ctx,
        probs: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        total_num_tokens: int,
        num_experts: int,
        topk: int,
        coeff: float,
    ):
        # pylint: disable=missing-function-docstring
        # probs is expected to already be 2D here (no reshape is performed)
        num_rows = probs.size(0)
        num_cols = probs.size(1)

        aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=total_num_tokens,
            num_experts=num_experts,
            num_rows=num_rows,
            num_cols=num_cols,
            topk=topk,
            coeff=coeff,
        )
        # Tensors needed by the backward fused kernel
        ctx.save_for_backward(Const_buf, tokens_per_expert)
        # Non-tensor metadata for the backward fused kernel
        ctx.num_rows = num_rows
        ctx.num_cols = num_cols
        return aux_loss

    @staticmethod
    def backward(ctx, grad_aux_loss):
        # pylint: disable=missing-function-docstring
        Const_buf, tokens_per_expert = ctx.saved_tensors
        grad_probs = tex.fused_moe_aux_loss_bwd(
            Const_buf=Const_buf,
            tokens_per_expert=tokens_per_expert,
            num_rows=ctx.num_rows,
            num_cols=ctx.num_cols,
            grad_aux_loss=grad_aux_loss,
        )
        # One gradient per forward input; only probs is differentiable
        return grad_probs, None, None, None, None, None


def fused_moe_aux_loss(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    num_experts: int,
    topk: int,
    coeff: float,
):
    """
    Fused MoE aux loss.

    Parameters
    ----------
    probs : torch.Tensor
    tokens_per_expert : torch.Tensor
        the number of tokens per expert
    total_num_tokens : int
        the total number of tokens, involved in the aux loss calculation
    num_experts : int
    topk : int
    coeff : float
        the coefficient of the aux loss

    Returns
    -------
    aux_loss : torch.scalar
    """
    return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)