# expert_specialization.py
import torch


def es_fp8_blockwise_scaled_grouped_mm(
    output,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_d,
    problem_sizes,
    expert_offsets,
    workspace,
) -> None:
    """Dispatch to the ``sgl_kernel`` op ``es_fp8_blockwise_scaled_grouped_mm``.

    This is a thin Python entry point: every argument is forwarded
    positionally, in the order declared here, to
    ``torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default``.
    No validation or conversion is performed on the Python side.

    NOTE(review): the argument semantics below are inferred from the
    parameter names only and must be confirmed against the kernel's
    registered schema:

    * ``output`` -- presumably the destination buffer written by the op.
    * ``a``, ``b`` -- presumably the matrix operands.
    * ``scales_a``, ``scales_b`` -- presumably the scale factors for
      ``a`` and ``b``.
    * ``stride_a``, ``stride_b``, ``stride_d`` -- presumably stride
      descriptors for the operands and the result.
    * ``problem_sizes``, ``expert_offsets`` -- presumably the per-group
      problem shapes and offsets.
    * ``workspace`` -- presumably scratch storage for the kernel.

    Returns:
        None. Whatever the underlying op returns is discarded.
    """
    # Call the ``.default`` overload explicitly, as the original code does.
    torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
        output,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_d,
        problem_sizes,
        expert_offsets,
        workspace,
    )