# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utility methods for model layers."""
4

5
from collections.abc import Callable
6
7
8

import torch

9
10
from vllm import _custom_ops as ops
from vllm import envs
11
from vllm._aiter_ops import rocm_aiter_ops
12
from vllm.logger import init_logger
13
from vllm.platforms import CpuArchEnum, current_platform
14
from vllm.utils.platform_utils import num_compute_units
15
from vllm.utils.torch_utils import direct_register_custom_op
16

17
18
logger = init_logger(__name__)

19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Leaf names (the text after the last "." in a layer prefix) that are treated
# as MoE router/gate layers by `is_layer_moe_router_gate`.
MOE_LAYER_ROUTER_GATE_SUFFIXES = {
    "gate",
    "router",
    "router_gate",
    "shared_expert_gate",
    "expert_gate",
}


def is_layer_moe_router_gate(prefix: str) -> bool:
    """Return whether the last dotted component of `prefix` is a gate name.

    Only the text after the final "." is compared (the whole string when
    there is no "."), and it must match an entry of
    `MOE_LAYER_ROUTER_GATE_SUFFIXES` exactly. An empty prefix is never a
    router gate.
    """
    if prefix:
        _, _, leaf_name = prefix.rpartition(".")
        return leaf_name in MOE_LAYER_ROUTER_GATE_SUFFIXES
    return False

33
34
35
36
37

def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count how often each token id occurs per sequence.

    Args:
        tokens: Token ids, used as the scatter index along dim 1. Assumed to
            be a 2-D int64 tensor with one row per sequence; the id
            `vocab_size` acts as the padding value -- TODO confirm with
            callers.
        vocab_size: Number of real token ids; the returned tensors have this
            many columns.
        num_seqs: Number of rows of the returned tensors.

    Returns:
        A `(bin_counts, mask)` pair, both of shape `[num_seqs, vocab_size]`:
        `bin_counts` holds the int64 occurrence count of every token id and
        `mask` is True wherever that count is positive.
    """
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding: the extra last column collects the padding
    # id and is sliced away below.
    bin_counts = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


51
52
53
54
55
56
57
58
def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor and returns it.

    logits: The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape

    # Only the presence mask of the prompt tokens is needed; output tokens
    # additionally need their counts for the frequency penalty.
    _, prompt_mask = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Apply repetition penalties as a custom op
    from vllm._custom_ops import apply_repetition_penalties

    apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # unsqueeze(dim=1) turns the per-sequence penalties into a column so they
    # broadcast across the vocabulary dimension.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
90
91


92
93
94
95
def default_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
96
    bias: torch.Tensor | None = None,
97
):
98
99
100
    return torch.nn.functional.linear(x, weight, bias)


101
102
def use_aiter_triton_gemm(n: int, m: int, k: int, dtype: torch.dtype) -> bool:
    """Decide whether the AITER Triton GEMM should be used for this problem.

    Args:
        n: Number of input rows.
        m: Number of weight rows (output features).
        k: Number of weight columns (reduction dimension).
        dtype: Activation dtype; only float16 and bfloat16 qualify.

    Returns:
        True only when the Triton GEMM is enabled, the platform is not an
        fp8-fnuz one, the dtype is float16/bfloat16, the problem is not one
        of the "larger" GEMMs, and `(m, k)` is one of the whitelisted shapes.
    """
    if (
        # Cheapest check first: unsupported dtypes never reach the platform
        # probes below.
        dtype not in (torch.float16, torch.bfloat16)
        or not rocm_aiter_ops.is_triton_gemm_enabled()
        # MI300's - fp8nuz=True
        or current_platform.is_fp8_fnuz()
    ):
        return False

    # use hipblaslt for the larger GEMMs
    if n > 2048 and m > 512:
        return False

    # (m, k) weight shapes for which the Triton kernel is selected.
    triton_shapes = (
        (5120, 2880),
        (2880, 4096),
        (128, 2880),
        (640, 2880),
        (2880, 512),
    )
    return any(m == shape_m and k == shape_k for shape_m, shape_k in triton_shapes)


122
def rocm_unquantized_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Unquantized GEMM for ROCm that picks a kernel from the problem size.

    With `n` the number of rows of `x` once flattened to 2-D and `(m, k)` the
    shape of `weight`, the kernels are tried in this order:

    1. AITER Triton `gemm_a16w16` when `use_aiter_triton_gemm` approves.
    2. `ops.wvSplitKrc` when skinny GEMMs are enabled, on gfx950, for
       float16/bfloat16 inputs with `10 <= n <= 128`, `k > 512`,
       `k % 8 == 0`, `m % 16 == 0`, contiguous `x`, and an estimated CU
       requirement that fits the device.
    3. `ops.wvSplitK` (`m > 8`, `0 < n <= 4`) or `ops.LLMM1` (`m % 4 == 0`,
       `n == 1`, `k <= 8192`, no bias) when skinny GEMMs are enabled, on
       gfx9, for float16/bfloat16 inputs with `k % 8 == 0`.
    4. `torch.nn.functional.linear` otherwise.

    Args:
        x: Input activations; the last dimension is the reduction dimension.
        weight: 2-D weight of shape `(m, k)`.
        bias: Optional bias.

    Returns:
        Tensor of shape `(*x.shape[:-1], m)`.
    """
    from vllm.platforms.rocm import on_gfx9, on_gfx950

    n = x.numel() // x.size(-1)
    m = weight.shape[0]
    k = weight.shape[1]

    cu_count = num_compute_units()

    if use_aiter_triton_gemm(n, m, k, x.dtype):
        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

        return gemm_a16w16(x, weight, bias)

    # Next ^2 of n
    n_pow2 = 1 << (n - 1).bit_length()
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
    # How many of 4 waves in a group can work on same 16 Ms at same time?
    # This reduces the Ms each group works on, i.e. increasing the number
    # of CUs needed.
    grps_shr_b = min(n_pow2 // 16, 4)
    # Given the above, how many CUs would we need?
    cu_needed = rndup_cus * grps_shr_b
    # candidate for atomic reduce count splitk?
    fits_wvsplitkrc = cu_needed <= cu_count

    use_skinny_reduce_counting = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx950()
        and x.dtype in (torch.float16, torch.bfloat16)
        and (
            10 <= n <= 128
            and k % 8 == 0
            and k > 512
            and m % 16 == 0
            and fits_wvsplitkrc
            and x.is_contiguous()
        )
    )
    if use_skinny_reduce_counting:
        x_view = x.reshape(-1, x.size(-1))
        out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
        return out.reshape(*x.shape[:-1], m)

    use_skinny = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx9()
        and x.dtype in (torch.float16, torch.bfloat16)
        and k % 8 == 0
    )

    if not use_skinny:
        return torch.nn.functional.linear(x, weight, bias)

    # The skinny kernels take a 2-D activation; the leading dimensions are
    # restored on the way out.
    x_view = x.reshape(-1, x.size(-1))
    if m > 8 and 0 < n <= 4:
        # cu_count was already queried above; no need to ask again.
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.reshape(*x.shape[:-1], m)
    if m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        out = ops.LLMM1(weight, x_view, 4)
        return out.reshape(*x.shape[:-1], m)
    return torch.nn.functional.linear(x, weight, bias)


189
def rocm_unquantized_gemm_fake(
190
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
191
) -> torch.Tensor:
192
193
194
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


195
196
197
198
def rocm_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """ROCm unquantized GEMM entry point.

    Forwards to the `torch.ops.vllm.rocm_unquantized_gemm` custom op.
    `layer` is not used; it is accepted so this function has the same call
    signature as the other unquantized GEMM backends.
    """
    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
202
203
204


# Register the ROCm GEMM as a custom op (invoked above through
# `torch.ops.vllm.rocm_unquantized_gemm`), pairing the real implementation
# with its shape-only fake implementation.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_fake,
)


211
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    """Return whether the packed-weight CPU kernel can serve this weight.

    Args:
        n: First dimension of the weight; must be a multiple of 16.
        k: Second dimension of the weight; must be a multiple of 32.
        dtype: Weight dtype; must be bfloat16 or int8.

    Returns:
        True only if all constraints above hold and the CPU reports AMX tile
        support.
    """
    return (
        # Cheap dtype/shape checks come first so the hardware probe only runs
        # for otherwise eligible weights.
        dtype in (torch.bfloat16, torch.int8)
        and k % 32 == 0
        and n % 16 == 0
        and torch._C._cpu._is_amx_tile_supported()
    )
218
219


220
221
222
223
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Choose a CPU GEMM backend for `layer` and store it as `layer.cpu_linear`.

    The stored callable is always invoked as `cpu_linear(x, weight, bias)`.
    Backends are tried in this order:

    1. Meta weights: plain `torch.nn.functional.linear`.
    2. The packed-weight kernel, when `envs.VLLM_CPU_SGL_KERNEL` is set and
       `check_cpu_sgl_kernel` accepts the weight shape and dtype.
    3. oneDNN, when supported and the CPU architecture is not PowerPC. A
       `RuntimeError` during setup is logged and falls through to 4.
    4. `torch.nn.functional.linear`.

    Args:
        layer: Module with a 2-D `weight` (and optionally a `bias`).
        remove_weight: When True and backend 2 or 3 is selected, replace
            `layer.weight` with an empty parameter, since those backends
            keep their own handle on the weight.
    """
    # skip for missing layers
    if layer.weight.is_meta:
        layer.cpu_linear = torch.nn.functional.linear
        return

    out_features, in_features = layer.weight.size()
    dtype = layer.weight.dtype

    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(
        out_features, in_features, dtype
    ):
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        # The closure ignores the `weight` argument and uses the packed copy;
        # the float32 bias is applied only when the caller passes a bias.
        layer.cpu_linear = lambda x, weight, bias: torch.ops._C.weight_packed_linear(
            x, packed_weight, bias_f32 if bias is not None else None, True
        )
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
        return

    if (
        ops._supports_onednn
        and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
    ):
        try:
            origin_weight = layer.weight
            handler = ops.create_onednn_mm(origin_weight.t(), 32)
            # The closure ignores the `weight` argument; the handler was
            # created from the transposed original weight.
            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
            if remove_weight:
                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
            return
        except RuntimeError as e:
            logger.warning_once(
                "Failed to create oneDNN linear, fallback to torch linear."
                f" Exception: {e}"
            )

    # fallback case
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
        x, weight, bias
    )
265
266


267
268
269
270
def cpu_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
271
    bias: torch.Tensor | None = None,
272
):
273
    return layer.cpu_linear(x, weight, bias)
274
275


276
277
278
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM function for the current platform.

    ROCm gets `rocm_unquantized_gemm`, CPU gets `cpu_unquantized_gemm`, and
    every other platform gets `default_unquantized_gemm`. All three are
    called as `(layer, x, weight, bias=None)`.
    """
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    if current_platform.is_cpu():
        return cpu_unquantized_gemm
    return default_unquantized_gemm