utils.py 8.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utility methods for model layers."""
4

5
from collections.abc import Callable
6
7
8

import torch

9
10
from vllm import _custom_ops as ops
from vllm import envs
11
from vllm._aiter_ops import rocm_aiter_ops
12
from vllm.logger import init_logger
13
from vllm.platforms import CpuArchEnum, current_platform
14
from vllm.utils.platform_utils import get_cu_count
15
from vllm.utils.torch_utils import direct_register_custom_op
16

17
# Module-level logger (used for the oneDNN fallback warning in
# dispatch_cpu_unquantized_gemm).
logger = init_logger(__name__)
18

19

20
21
22
23
24
25
26
27
28
29
30
31
32
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of ``w`` along its last dimension.

    Element ``j`` of the first half and element ``j`` of the second half are
    placed next to each other, i.e. ``[a0, ..., a(h-1), b0, ..., b(h-1)]``
    becomes ``[a0, b0, a1, b1, ...]``. This layout is used together with the
    Triton SwiGLU kernel.

    Example::

        [[1, 2, 3, 4, 5, 6],          [[1, 4, 2, 5, 3, 6],
         [7, 8, 9, 10, 11, 12]]  -->   [7, 10, 8, 11, 9, 12]]

    The last dimension must be even (otherwise the two halves differ in
    size and ``torch.stack`` raises).
    """
    half = w.shape[-1] // 2
    interleaved = torch.stack((w[..., :half], w[..., half:]), dim=-1)
    return interleaved.reshape(w.shape)
39

40
41
42
43
44

def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count token occurrences per sequence and flag which tokens appeared.

    Args:
        tokens: 2-D int64 tensor of token ids, one row per sequence. Slots
            holding ``vocab_size`` are treated as padding and not counted.
        vocab_size: Number of real token ids.
        num_seqs: Number of rows in the result.

    Returns:
        ``(bin_counts, mask)``. ``bin_counts`` is a ``torch.long`` tensor of
        shape ``(num_seqs, vocab_size)`` where ``bin_counts[i, t]`` is the
        number of times token ``t`` occurs in row ``i``. ``mask`` is
        ``bin_counts > 0``.
    """
    # The extra trailing column absorbs the `vocab_size` padding id; it is
    # sliced away before returning.
    padded_counts = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    padded_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = padded_counts[:, :vocab_size]
    return counts, counts > 0


58
59
60
61
62
63
64
65
def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Args:
        logits: Logits of shape ``(num_seqs, vocab_size)``. Modified in place
            and also returned.
        prompt_tokens_tensor: Prompt token ids, one row per sequence, padded to
            the longest prompt in the batch with ``vocab_size``. That value is
            used for padding because it is not a valid token id.
        output_tokens_tensor: Output token ids, one row per sequence. Any
            padding presumably also uses ``vocab_size``, whose column is
            discarded when counting.
        presence_penalties: Shape ``(num_seqs,)``. Subtracted once for every
            token that occurs in the output.
        frequency_penalties: Shape ``(num_seqs,)``. Multiplied by each token's
            output count and subtracted.
        repetition_penalties: Shape ``(num_seqs,)``. Passed, together with the
            prompt and output presence masks, to the
            ``apply_repetition_penalties`` custom op.

    Returns:
        ``logits`` after all penalties have been applied.
    """
    num_seqs, vocab_size = logits.shape

    prompt_seen = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )[1]
    output_counts, output_seen = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Repetition penalties are handled by a custom op.
    from vllm._custom_ops import apply_repetition_penalties

    apply_repetition_penalties(logits, prompt_seen, output_seen, repetition_penalties)

    # Frequency/presence penalties follow the OpenAI API definition:
    # https://platform.openai.com/docs/api-reference/parameter-details
    freq_col = frequency_penalties.unsqueeze(dim=1)
    pres_col = presence_penalties.unsqueeze(dim=1)
    logits -= freq_col * output_counts
    logits -= pres_col * output_seen
    return logits
97
98


99
100
101
102
def default_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
103
    bias: torch.Tensor | None = None,
104
):
105
106
107
    return torch.nn.functional.linear(x, weight, bias)


108
109
def use_aiter_triton_gemm(n, m, k, dtype):
    """Decide whether aiter's Triton a16w16 GEMM should handle this problem.

    ``n`` is the number of input rows, ``m`` the output features
    (``weight.shape[0]``) and ``k`` the input features (``weight.shape[1]``).
    Returns True only when the aiter Triton GEMM is enabled, the platform is
    not fp8-fnuz, ``dtype`` is fp16/bf16, the GEMM is not "large" (left to
    hipBLASLt), and ``(m, k)`` is one of a fixed set of tuned shapes.
    """
    if not rocm_aiter_ops.is_triton_gemm_enabled():
        return False
    # MI300's - fp8nuz=True
    if current_platform.is_fp8_fnuz():
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        return False

    # use hipblaslt for the larger GEMMs
    if n > 2048 and m > 512:
        return False

    tuned_shapes = (
        (5120, 2880),
        (2880, 4096),
        (128, 2880),
        (640, 2880),
        (2880, 512),
    )
    return (m, k) in tuned_shapes


129
def rocm_unquantized_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Unquantized GEMM for ROCm that selects a specialised kernel when possible.

    Produces the same result as ``torch.nn.functional.linear(x, weight, bias)``.
    Depending on the GPU generation, dtype and problem shape it dispatches to
    one of:

    * ``ops.wvSplitKrc``: gfx950, fp16/bf16, contiguous ``x``,
      ``16 <= n <= 128``, ``k > 512``, and
      ``ceil(k / 512) * ceil(m / 16)`` below the CU count.
    * aiter's Triton ``gemm_a16w16``: when ``use_aiter_triton_gemm`` agrees.
    * ``ops.wvSplitK`` or ``ops.LLMM1``: gfx9 skinny kernels for tiny ``n``.

    Anything else falls back to ``torch.nn.functional.linear``. The skinny
    kernels are only considered when ``envs.VLLM_ROCM_USE_SKINNY_GEMM`` is set.

    Here ``n`` is the number of rows of ``x`` with all leading dimensions
    flattened, ``m = weight.shape[0]`` and ``k = weight.shape[1]``.
    """
    import math

    from vllm.platforms.rocm import on_gfx9, on_gfx950

    half_dtypes = (torch.float16, torch.bfloat16)
    # Row count of x with all leading dims flattened. x.numel() is an exact
    # multiple of x.size(-1), so floor division keeps this an int.
    n = x.numel() // x.size(-1)
    m = weight.shape[0]
    k = weight.shape[1]
    out_shape = (*x.shape[:-1], m)

    # Reduce-counting split-K kernel (gfx950 only).
    if (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx950()
        and x.dtype in half_dtypes
        and 16 <= n <= 128
        and k > 512
        and x.is_contiguous()
    ):
        # Query the CU count once and reuse it for both the check and the call.
        cu_count = get_cu_count()
        if math.ceil(k / 512) * math.ceil(m / 16) < cu_count:
            out = ops.wvSplitKrc(weight, x.reshape(-1, x.size(-1)), cu_count, bias)
            return out.reshape(out_shape)

    if use_aiter_triton_gemm(n, m, k, x.dtype):
        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

        return gemm_a16w16(x, weight, bias)

    use_skinny = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx9()
        and x.dtype in half_dtypes
        and k % 8 == 0
        and x.is_contiguous()
    )
    if not use_skinny:
        return torch.nn.functional.linear(x, weight, bias)

    x_view = x.reshape(-1, x.size(-1))
    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, get_cu_count(), bias)
        return out.reshape(out_shape)
    if m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        out = ops.LLMM1(weight, x_view, 4)
        return out.reshape(out_shape)

    return torch.nn.functional.linear(x, weight, bias)


186
def rocm_unquantized_gemm_fake(
187
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
188
) -> torch.Tensor:
189
190
191
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


192
193
194
195
def rocm_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """ROCm unquantized GEMM entry point with the common backend signature.

    Forwards to the registered ``torch.ops.vllm.rocm_unquantized_gemm``
    custom op. ``layer`` is accepted for signature compatibility and unused.
    """
    gemm_op = torch.ops.vllm.rocm_unquantized_gemm
    return gemm_op(x, weight, bias)
199
200
201


# Register rocm_unquantized_gemm_impl as torch.ops.vllm.rocm_unquantized_gemm
# (called by rocm_unquantized_gemm), with rocm_unquantized_gemm_fake as the
# fake implementation.
direct_register_custom_op(
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_fake,
    op_name="rocm_unquantized_gemm",
)


208
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
209
210
211
212
213
214
    return (
        torch._C._cpu._is_amx_tile_supported()
        and (dtype in (torch.bfloat16, torch.int8))
        and k % 32 == 0
        and n % 16 == 0
    )
215
216


217
218
219
220
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
221
222
223
224
225
    # skip for missing layers
    if layer.weight.is_meta:
        layer.cpu_linear = torch.nn.functional.linear
        return

226
227
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
228

229
230
231
232
233
234
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
235
236
237
        layer.cpu_linear = lambda x, weight, bias: torch.ops._C.weight_packed_linear(
            x, packed_weight, bias_f32 if bias is not None else None, True
        )
238
        if remove_weight:
239
            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
240
        return
241
242
243
244
    elif (
        ops._supports_onednn
        and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
    ):
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
        try:
            origin_weight = layer.weight
            handler = ops.create_onednn_mm(origin_weight.t(), 32)
            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
            if remove_weight:
                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
            return
        except RuntimeError as e:
            logger.warning_once(
                "Failed to create oneDNN linear, fallback to torch linear."
                f" Exception: {e}"
            )

    # fallback case
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
        x, weight, bias
    )
262
263


264
265
266
267
def cpu_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
268
    bias: torch.Tensor | None = None,
269
):
270
    return layer.cpu_linear(x, weight, bias)
271
272


273
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM implementation for the current platform.

    CPU platforms get ``cpu_unquantized_gemm``; every other platform gets
    ``default_unquantized_gemm``.
    """
    # NOTE(review): a ROCm branch (`if current_platform.is_rocm(): return
    # rocm_unquantized_gemm`) is commented out in this tree, so ROCm also
    # receives default_unquantized_gemm. Confirm this is intentional before
    # re-enabling it.
    return (
        cpu_unquantized_gemm if current_platform.is_cpu() else default_unquantized_gemm
    )