utils.py 6.77 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utility methods for model layers."""
4

5
from collections.abc import Callable
6
7
8

import torch

9
10
from vllm import _custom_ops as ops
from vllm import envs
11
from vllm.platforms import CpuArchEnum, current_platform
12
from vllm.utils import direct_register_custom_op
13

14

15
16
17
18
19
20
21
22
23
24
25
26
27
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
    """Interleave the two halves of ``w`` along its last dimension.

    The last dimension is split into a first half ``[a0, a1, ...]`` and a
    second half ``[b0, b1, ...]``, which are folded into adjacent positions
    ``[a0, b0, a1, b1, ...]``. Leading dimensions are left untouched.

    Example::

        input:  [[1, 2, 3, 4, 5, 6],
                 [7, 8, 9, 10, 11, 12]]
        output: [[1, 4, 2, 5, 3, 6],
                 [7, 10, 8, 11, 9, 12]]

    This layout is used together with the Triton SwiGLU kernel.

    Args:
        w: Tensor whose last dimension has even size.

    Returns:
        A tensor with the same shape as ``w`` and the interleaved layout.

    Raises:
        ValueError: If the last dimension has odd size. Such a dimension
            cannot be split into two equal halves.
    """
    shape = w.shape
    n = shape[-1]
    if n % 2 != 0:
        # Without this check torch.stack fails later with an opaque size
        # mismatch error.
        raise ValueError(
            f"shuffle_weight expects an even last dimension, got {n}"
        )
    half = n // 2
    first = w[..., :half]
    second = w[..., half:]
    # Stacking on a new trailing axis pairs first[i] with second[i]. Reshaping
    # back to the original shape then lays each pair out contiguously.
    return torch.stack((first, second), dim=-1).reshape(shape)


36
37
38
39
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count how often each token id occurs in every sequence.

    Args:
        tokens: Token ids used as the scatter index along dim 1. Row ``i``
            holds the ids for sequence ``i``. Entries equal to
            ``vocab_size`` are treated as padding and are not counted.
        vocab_size: Number of real token ids.
        num_seqs: Number of rows in the result.

    Returns:
        ``(bin_counts, mask)``. ``bin_counts`` is an int64 tensor of shape
        ``(num_seqs, vocab_size)``. ``mask`` is the boolean tensor
        ``bin_counts > 0``.
    """
    # An extra trailing bin (index ``vocab_size``) absorbs the padding id.
    # It is dropped before returning.
    padded_counts = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    increments = torch.ones_like(tokens)
    padded_counts.scatter_add_(1, tokens, increments)

    counts = padded_counts[:, :vocab_size]
    return counts, counts > 0


53
54
55
56
57
58
59
60
def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Args:
        logits: Logits of shape ``[num_seqs, vocab_size]``. The tensor is
            modified in place and also returned.
        prompt_tokens_tensor: Prompt token ids, padded to the longest prompt
            in the batch with ``vocab_size``. That value is used for padding
            because it is not a valid token id in the vocabulary.
        output_tokens_tensor: Generated token ids. Presumably padded with
            ``vocab_size`` like the prompt tokens (confirm against callers).
        presence_penalties: Shape ``(num_seqs,)``. Subtracted once for every
            token that appears in the output.
        frequency_penalties: Shape ``(num_seqs,)``. Multiplied by the number
            of times each token appears in the output, then subtracted.
        repetition_penalties: Shape ``(num_seqs,)``. Applied by the
            ``apply_repetition_penalties`` custom op using the prompt and
            output occurrence masks.

    Returns:
        The same ``logits`` tensor.
    """
    num_seqs, vocab_size = logits.shape

    _, prompt_mask = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )
    output_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Repetition penalties go through a custom op. Its return value is unused,
    # so it is expected to update ``logits`` in place.
    ops.apply_repetition_penalties(
        logits, prompt_mask, output_mask, repetition_penalties
    )

    # Frequency/presence penalties follow the OpenAI API definition:
    # https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties[:, None] * output_counts
    logits -= presence_penalties[:, None] * output_mask
    return logits
92
93


94
95
96
97
def default_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
98
    bias: torch.Tensor | None = None,
99
):
100
101
102
    return torch.nn.functional.linear(x, weight, bias)


103
def rocm_unquantized_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    """Unquantized GEMM for ROCm, using skinny-GEMM kernels when applicable.

    The skinny path is considered only when all of these hold:
    ``VLLM_ROCM_USE_SKINNY_GEMM`` is enabled, ``on_gfx9()`` is true, ``x`` is
    fp16/bf16, and the reduction size ``k = weight.shape[1]`` is a multiple
    of 8. Otherwise this falls back to ``torch.nn.functional.linear``.

    Within the skinny path, ``x`` is flattened to ``(n, k)`` and
    ``m = weight.shape[0]``:
      * ``ops.wvSplitK`` is used when ``m > 8`` and ``0 < n <= 4``.
      * ``ops.LLMM1`` is used when ``m % 4 == 0``, ``n == 1``, ``k <= 8192``
        and there is no bias.
      * ``torch.nn.functional.linear`` is used otherwise.

    Returns:
        A tensor of shape ``(*x.shape[:-1], weight.shape[0])``.
    """
    from vllm.platforms.rocm import on_gfx9

    k = weight.shape[1]
    use_skinny = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx9()
        and x.dtype in [torch.float16, torch.bfloat16]
        and k % 8 == 0
    )

    if not use_skinny:
        return torch.nn.functional.linear(x, weight, bias)

    # Collapse all leading dims so the kernels see a 2-D (n, k) activation.
    x_view = x.view(-1, x.size(-1))
    n = x_view.shape[0]
    m = weight.shape[0]
    cu_count = current_platform.get_cu_count()

    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.view(*x.shape[:-1], m)
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        # The meaning of the literal 4 passed to LLMM1 is not visible here.
        # It presumably matches the ``m % 4`` requirement; TODO confirm.
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], m)
    return torch.nn.functional.linear(x, weight, bias)


133
def rocm_unquantized_gemm_impl_fake(
134
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
135
) -> torch.Tensor:
136
137
138
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


139
140
141
142
def rocm_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """ROCm GEMM entry point that calls the registered custom op.

    The call goes through ``torch.ops.vllm.rocm_unquantized_gemm_impl``
    rather than the Python function directly. Presumably this is so graph
    capture/compilation sees an opaque op with a registered fake
    implementation (confirm).

    ``layer`` is not used. It is accepted so the signature matches the other
    GEMM dispatch targets.
    """
    gemm_op = torch.ops.vllm.rocm_unquantized_gemm_impl
    return gemm_op(x, weight, bias)


# Register the ROCm GEMM as ``torch.ops.vllm.rocm_unquantized_gemm_impl``.
# The fake implementation supplies the output shape/dtype without running
# the real kernels.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    fake_impl=rocm_unquantized_gemm_impl_fake,
    op_func=rocm_unquantized_gemm_impl,
)


155
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
156
157
158
159
160
161
    return (
        torch._C._cpu._is_amx_tile_supported()
        and (dtype in (torch.bfloat16, torch.int8))
        and k % 32 == 0
        and n % 16 == 0
    )
162
163


164
165
166
167
168
169
170
171
172
173
174
175
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    """Choose a CPU GEMM backend for ``layer`` and install it as ``layer.cpu_linear``.

    The installed callable has the signature ``(x, weight, bias) -> Tensor``.
    Backends are tried in this order:

    1. The SGL weight-packed kernel. Used when ``VLLM_CPU_SGL_KERNEL`` is set
       and ``check_cpu_sgl_kernel`` accepts the weight's size and dtype. The
       weight is packed once, here.
    2. oneDNN matmul. Used when ``ops._supports_onednn`` is true and the CPU
       is x86 or ``ops.is_onednn_acl_supported()`` is true. A oneDNN handler
       is created here from the transposed weight.
    3. Otherwise, ``torch.nn.functional.linear`` applied to the arguments
       passed at call time.

    If ``remove_weight`` is true and backend 1 or 2 is chosen,
    ``layer.weight`` is replaced by an empty, non-trainable parameter. That
    backend is built from the original weight tensor.
    """
    out_dim, in_dim = layer.weight.size()
    weight_dtype = layer.weight.dtype

    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(
        out_dim, in_dim, weight_dtype
    ):
        # Pack before the weight may be removed below.
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        layer_bias = getattr(layer, "bias", None)
        bias_f32 = None if layer_bias is None else layer_bias.to(torch.float32)

        def _sgl_linear(x, weight, bias):
            # The per-call ``bias`` only decides WHETHER a bias is applied.
            # The value used is the float32 copy of ``layer.bias`` taken at
            # dispatch time.
            return torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32 if bias is not None else None, True
            )

        layer.cpu_linear = _sgl_linear
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
    elif ops._supports_onednn and (
        current_platform.get_cpu_architecture() == CpuArchEnum.X86
        or ops.is_onednn_acl_supported()
    ):
        # Keep a reference to the real weight before it may be swapped out.
        original_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
        # The meaning of the literal 32 is not visible here; TODO confirm
        # against ops.create_onednn_mm.
        mm_handler = ops.create_onednn_mm(original_weight.t(), 32)

        def _onednn_linear(x, weight, bias):
            # ``weight`` is ignored; the handler holds the GEMM setup.
            return ops.onednn_mm(mm_handler, x, bias)

        layer.cpu_linear = _onednn_linear
    else:

        def _torch_linear(x, weight, bias):
            return torch.nn.functional.linear(x, weight, bias)

        layer.cpu_linear = _torch_linear
194
195


196
197
198
199
def cpu_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
200
    bias: torch.Tensor | None = None,
201
):
202
    return layer.cpu_linear(x, weight, bias)
203
204


205
206
207
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    """Return the unquantized GEMM function for the current platform.

    Mapping:
      * ROCm -> ``rocm_unquantized_gemm``
      * CPU -> ``cpu_unquantized_gemm``
      * any other platform -> ``default_unquantized_gemm``
    """
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    if current_platform.is_cpu():
        return cpu_unquantized_gemm
    return default_unquantized_gemm