"vscode:/vscode.git/clone" did not exist on "e302950da3bcd8c6bbdf4ac3897282decedb83e4"
marlin.py 1006 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch


def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack a GPTQ-quantized weight into the Marlin kernel layout.

    Thin wrapper over the ``sgl_kernel.gptq_marlin_repack`` custom op.

    Args:
        b_q_weight: Packed quantized weight to repack.
        perm: Permutation tensor forwarded unchanged to the op
            (presumably the GPTQ act-order permutation — TODO confirm).
        size_k: Size of the K (input) dimension of the unpacked weight.
        size_n: Size of the N (output) dimension of the unpacked weight.
        num_bits: Quantization bit width forwarded to the op.

    Returns:
        The tensor produced by the custom op. Previously this value was
        discarded and the wrapper returned ``None``.
    """
    # The op computes the repacked weight and returns it; it must be
    # propagated to the caller rather than dropped.
    return torch.ops.sgl_kernel.gptq_marlin_repack.default(
        b_q_weight,
        perm,
        size_k,
        size_n,
        num_bits,
    )


def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack a single AWQ-quantized weight into the Marlin kernel layout.

    Delegates to the ``sgl_kernel.awq_marlin_repack`` custom op and returns
    whatever tensor that op produces.

    Args:
        b_q_weight: Packed quantized weight to repack.
        size_k: Size of the K (input) dimension of the unpacked weight.
        size_n: Size of the N (output) dimension of the unpacked weight.
        num_bits: Quantization bit width forwarded to the op.
    """
    repack_op = torch.ops.sgl_kernel.awq_marlin_repack
    repacked = repack_op(b_q_weight, size_k, size_n, num_bits)
    return repacked


def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack a stack of per-expert AWQ weights into the Marlin layout.

    Each slice ``b_q_weight[i]`` along the leading (expert) dimension is
    repacked independently with ``sgl_kernel.awq_marlin_repack``. The results
    are written into one preallocated tensor whose dtype and device match
    ``b_q_weight``.

    Args:
        b_q_weight: Per-expert packed weights; dimension 0 indexes experts.
        perm: Accepted for signature compatibility; not used here.
        size_k: K dimension of each expert's weight; must be a multiple of 16.
        size_n: N dimension of each expert's weight.
        num_bits: Quantization bit width forwarded to the op.

    Returns:
        Tensor of shape ``(num_experts, size_k // 16, size_n * (num_bits // 2))``.
    """
    expert_count = b_q_weight.shape[0]
    # The per-expert output height is size_k // 16, so K must divide evenly.
    assert size_k % 16 == 0
    out_rows = size_k // 16
    out_cols = size_n * (num_bits // 2)
    repacked = torch.empty(
        (expert_count, out_rows, out_cols),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for idx in range(expert_count):
        expert_weight = b_q_weight[idx]
        repacked[idx] = torch.ops.sgl_kernel.awq_marlin_repack(
            expert_weight, size_k, size_n, num_bits
        )
    return repacked