import torch


def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
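    """Sort token indices by their assigned expert and pad each expert's
    segment to a multiple of block_size, so that block-wise MoE GEMMs can
    assign whole blocks to a single expert. Results are written in place into
    sorted_token_ids, experts_ids, and num_tokens_post_pad; token_cnts_buffer
    and cumsum_buffer are scratch space.

    A minimal allocation sketch (shapes mirror typical callers and are an
    assumption, not enforced by this wrapper):

        max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
        sorted_token_ids = torch.empty(
            (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
        )
        experts_ids = torch.empty(
            ((max_num_tokens_padded + block_size - 1) // block_size,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        num_tokens_post_pad = torch.empty(
            (1,), dtype=torch.int32, device=topk_ids.device
        )
        token_cnts_buffer = torch.empty(
            ((num_experts + 1) * num_experts,),
            dtype=torch.int32,
            device=topk_ids.device,
        )
        cumsum_buffer = torch.empty(
            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
        )
    """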
    torch.ops.sgl_kernel.moe_align_block_size.default(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )


def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
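    """Softmax the router logits in gating_output (shape
    [num_tokens, num_experts]) and write the top-k expert weights, expert ids,
    and token-to-expert scatter indices into the pre-allocated topk_weights,
    topk_ids, and token_expert_indices tensors.

    A minimal sketch (dtypes follow typical callers and are an assumption, not
    enforced by this wrapper; topk is implied by the output shapes):

        num_tokens, _ = gating_output.shape
        topk_weights = torch.empty(
            (num_tokens, topk), dtype=torch.float32, device=gating_output.device
        )
        topk_ids = torch.empty(
            (num_tokens, topk), dtype=torch.int32, device=gating_output.device
        )
        token_expert_indices = torch.empty_like(topk_ids)
        topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)
    """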
    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, token_expert_indices, gating_output
    )


def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    n_share_experts_fusion=0,
    routed_scaling_factor=0,
):
    # This fused kernel selects the top-k experts in a hierarchical, two-layer
    # fashion: it splits the experts into num_expert_group groups, uses the sum
    # of the top-2 expert weights in each group as that group's weight to pick
    # topk_group groups, and then selects the top-k experts within the selected
    # groups.
    # The number of experts is inferred from the input tensor shape. Currently
    # only a power-of-2 number of experts is supported, the number of experts
    # must be divisible by num_expert_group, and #experts / num_expert_group
    # must be <= 32 for now.
    # For unsupported cases, use biased_grouped_topk in
    # sglang.srt.layers.moe.topk instead.
    # n_share_experts_fusion: if > 0, the last expert is replaced with a
    # round-robin shared expert
    # routed_scaling_factor: if > 0, the last expert's weight is scaled by
    # this factor
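    # Example (hypothetical DeepSeek-V3-style shapes; the kernel is assumed to
    # return the (topk_weights, topk_ids) pair):
    #   scores = torch.randn(num_tokens, 256, dtype=torch.float32, device="cuda")
    #   bias = torch.zeros(256, dtype=torch.float32, device="cuda")
    #   topk_weights, topk_ids = moe_fused_gate(
    #       scores, bias, num_expert_group=8, topk_group=4, topk=8
    #   )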
    return torch.ops.sgl_kernel.moe_fused_gate.default(
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        n_share_experts_fusion,
        routed_scaling_factor,
    )


def fp8_blockwise_scaled_grouped_mm(
    output,
    a_ptrs,
    b_ptrs,
    out_ptrs,
    a_scales_ptrs,
    b_scales_ptrs,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_c,
    layout_sfa,
    layout_sfb,
    problem_sizes,
    expert_offsets,
    workspace,
):
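    """Run one grouped GEMM over all experts' (M, N, K) problems with FP8
    inputs a/b and blockwise dequantization scales scales_a/scales_b, writing
    the result into output. The *_ptrs arguments are pre-allocated pointer
    buffers the kernel fills with per-expert device addresses; problem_sizes
    and expert_offsets are typically produced by prepare_moe_input below, and
    workspace is scratch space for the underlying kernel.
    """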
    torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
        output,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets,
        workspace,
    )


def prepare_moe_input(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
):
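    """Derive per-expert launch metadata for the two grouped GEMMs of an MoE
    layer from topk_ids: expert_offsets marks where each expert's token
    segment starts, problem_sizes1/problem_sizes2 hold the GEMM shapes, and
    input_permutation/output_permutation gather tokens into expert-contiguous
    order and scatter results back. All outputs are written in place.

    A minimal allocation sketch (shapes mirror typical callers and are an
    assumption, not enforced by this wrapper):

        expert_offsets = torch.empty(
            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
        )
        problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=topk_ids.device
        )
        problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=topk_ids.device
        )
        input_permutation = torch.empty(
            (topk_ids.numel(),), dtype=torch.int32, device=topk_ids.device
        )
        output_permutation = torch.empty(
            (topk_ids.numel(),), dtype=torch.int32, device=topk_ids.device
        )
    """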
    torch.ops.sgl_kernel.prepare_moe_input.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )