from typing import List, Optional

import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
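
# Usage sketch (illustrative, not from the library; assumes the usual AWQ
# layout of eight 4-bit values packed per int32 with a group size of 128):
#
#   k, n, group = 4096, 4096, 128
#   qweight = torch.randint(-2**31, 2**31 - 1, (k, n // 8), dtype=torch.int32, device="cuda")
#   qzeros = torch.randint(-2**31, 2**31 - 1, (k // group, n // 8), dtype=torch.int32, device="cuda")
#   scales = torch.randn(k // group, n, dtype=torch.float16, device="cuda")
#   w = awq_dequantize(qweight, scales, qzeros)  # dense fp16 weight, (k, n)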


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
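
# Usage sketch (illustrative; per-row scales for mat_a and per-column scales
# for mat_b are assumed here, and the kernel may additionally require a
# column-major mat_b -- check the kernel's tests for the exact layout):
#
#   m, k, n = 16, 256, 128
#   a = torch.randint(-128, 128, (m, k), dtype=torch.int8, device="cuda")
#   b = torch.randint(-128, 128, (k, n), dtype=torch.int8, device="cuda")
#   sa = torch.rand(m, 1, dtype=torch.float32, device="cuda")
#   sb = torch.rand(1, n, dtype=torch.float32, device="cuda")
#   c = int8_scaled_mm(a, b, sa, sb, torch.float16)  # (m, n) in out_dtype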


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )
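
# Usage sketch (illustrative; assumes DeepSeek-style blockwise scaling with
# 1x128 activation groups and 128x128 weight blocks -- the scale shapes and
# the required mat_b layout are assumptions, not guarantees):
#
#   m, k, n = 16, 512, 512
#   a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
#   b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
#   sa = torch.rand(m, k // 128, dtype=torch.float32, device="cuda")
#   sb = torch.rand(k // 128, n // 128, dtype=torch.float32, device="cuda")
#   c = fp8_blockwise_scaled_mm(a, b, sa, sb, torch.bfloat16)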


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return torch.ops.sgl_kernel.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
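
# Usage sketch (illustrative; per-tensor float32 scales are assumed, and as
# with the int8 path the expected mat_b layout may be column-major):
#
#   m, k, n = 16, 256, 128
#   a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
#   b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
#   sa = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#   sb = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#   c = fp8_scaled_mm(a, b, sa, sb, torch.bfloat16, bias=None)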


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    # Reuse a cached 32 MB cuBLAS workspace instead of allocating one per call.
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out
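
# Usage sketch (illustrative; cuBLAS fp8 matmuls generally want the second
# operand column-major, hence B is created as (b, n, k) and transposed):
#
#   A = torch.randn(2, 16, 32, device="cuda").to(torch.float8_e4m3fn)
#   B = torch.randn(2, 64, 32, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)
#   sa = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#   sb = torch.tensor(1.0, dtype=torch.float32, device="cuda")
#   out = bmm_fp8(A, B, sa, sb, torch.bfloat16)  # (2, 16, 64)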


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )
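
# Usage sketch (illustrative; fp8_min / fp8_max are the representable range of
# the target fp8 dtype, and one scale is written per group of `group_size`
# elements along the last dimension -- the output shapes are assumptions):
#
#   x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
#   group_size = 128
#   xq = torch.empty_like(x, dtype=torch.float8_e4m3fn)
#   xs = torch.empty(16, 1024 // group_size, device="cuda", dtype=torch.float32)
#   fmax = torch.finfo(torch.float8_e4m3fn).max
#   sgl_per_token_group_quant_fp8(x, xq, xs, group_size, 1e-10, -fmax, fmax)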


def sgl_per_token_group_quant_int8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    int8_min: float,
    int8_max: float,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
        input, output_q, output_s, group_size, eps, int8_min, int8_max
    )
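
# Usage sketch (illustrative; same shape conventions as the fp8 variant above,
# with the int8 representable range in place of the fp8 one):
#
#   x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
#   xq = torch.empty_like(x, dtype=torch.int8)
#   xs = torch.empty(16, 1024 // 128, device="cuda", dtype=torch.float32)
#   sgl_per_token_group_quant_int8(x, xq, xs, 128, 1e-10, -128.0, 127.0)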


def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
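
# Usage sketch (illustrative; reading `is_static` as "output_s already holds a
# precomputed scale" versus "the kernel computes the scale dynamically" is an
# assumption drawn from the parameter name):
#
#   x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
#   xq = torch.empty_like(x, dtype=torch.float8_e4m3fn)
#   s = torch.zeros(1, device="cuda", dtype=torch.float32)
#   sgl_per_tensor_quant_fp8(x, xq, s, is_static=False)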


def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.cublas_grouped_gemm(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )
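
# Usage sketch (illustrative; every problem in the group can have its own
# shape, the outputs must be preallocated, and the (n, k) weight layout shown
# here is an assumption -- verify against the kernel's tests):
#
#   inputs = [torch.randn(8, 64, device="cuda", dtype=torch.float16) for _ in range(4)]
#   weights = [torch.randn(16, 64, device="cuda", dtype=torch.float16) for _ in range(4)]
#   outputs = [torch.empty(8, 16, device="cuda", dtype=torch.float16) for _ in range(4)]
#   cublas_grouped_gemm(inputs, weights, outputs, torch.float16)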


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
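
# Usage sketch (illustrative; one dynamic scale per row/token is the assumed
# convention, mirroring the group variant with one group spanning the row):
#
#   x = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
#   xq = torch.empty_like(x, dtype=torch.float8_e4m3fn)
#   xs = torch.empty(16, 1, device="cuda", dtype=torch.float32)
#   sgl_per_token_quant_fp8(x, xq, xs)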