# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    """Allocate the paged-KV metadata buffers used by the AITER MLA decode op.

    Args:
        max_batch_size: Number of batch slots the buffers are sized for.
        block_size: Value every entry of ``paged_kv_last_page_lens`` is
            initialised to.
        max_block_per_batch: Number of KV blocks reserved per batch slot in
            ``paged_kv_indices``.
        device: Device on which all returned tensors are allocated.

    Returns:
        ``(paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens,
        qo_indptr)``, all ``int32`` tensors on ``device``:

        * ``paged_kv_indices`` -- zeros, shape
          ``(max_batch_size * max_block_per_batch,)``.
        * ``paged_kv_indptr`` -- zeros, shape ``(max_batch_size + 1,)``.
        * ``paged_kv_last_page_lens`` -- filled with ``block_size``, shape
          ``(max_batch_size,)``.
        * ``qo_indptr`` -- zeros, shape ``(max_batch_size + 1,)``.
    """
    paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
                                   dtype=torch.int32,
                                   device=device)
    paged_kv_indptr = torch.zeros(max_batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    # Allocate on ``device`` like the sibling buffers so all metadata handed
    # to the decode kernel lives on a single device.
    paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                         block_size,
                                         dtype=torch.int32,
                                         device=device)
    qo_indptr = torch.zeros(max_batch_size + 1,
                            dtype=torch.int32,
                            device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
) -> None:
    """Run MLA decode attention through the ``rocm_aiter_mla_decode_fwd`` op.

    Dispatches to ``torch.ops.vllm.rocm_aiter_mla_decode_fwd``. That custom
    op is registered with ``o`` as a mutated argument, so the result is
    written into ``o`` and nothing is returned.

    Args:
        q: Query tensor; its last dimension determines how ``kv_buffer`` is
            reshaped.
        kv_buffer: KV cache, viewed as ``(-1, 1, 1, q.shape[-1])`` before the
            call (assumes its element count is divisible by
            ``q.shape[-1]`` -- TODO confirm with callers).
        o: Output buffer, mutated in place by the op.
        sm_scale: Softmax scale, forwarded as a keyword argument.
        qo_indptr: Query offsets tensor forwarded to the op.
        max_seqlen_qo: Forwarded to the op unchanged.
        kv_indptr: Optional, forwarded to the op unchanged.
        kv_indices: Optional, forwarded to the op unchanged.
        kv_last_page_lens: Optional, forwarded to the op unchanged.
        logit_cap: Forwarded to the op as a keyword argument.
    """
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
                                             kv_buffer.view(
                                                 -1, 1, 1, q.shape[-1]),
                                             o,
                                             qo_indptr,
                                             max_seqlen_qo,
                                             kv_indptr,
                                             kv_indices,
                                             kv_last_page_lens,
                                             sm_scale=sm_scale,
                                             logit_cap=logit_cap)


def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Real implementation of the ``rocm_aiter_mla_decode_fwd`` custom op.

    Forwards every argument to ``aiter.mla.mla_decode_fwd``. Nothing is
    returned; the op's registration declares ``o`` as the mutated output.

    ``kv_buffer`` is viewed as ``(-1, 1, 1, q.shape[-1])`` before the call.
    Re-applying the view is harmless if the caller already reshaped it.
    """
    # Deferred import: the ``aiter`` package is only needed when the op is
    # actually executed, not when this module is imported.
    from aiter.mla import mla_decode_fwd

    # NOTE: ``max_seqlen_qo`` is passed positionally *after* the kv_* tensors,
    # which differs from this function's own parameter order.
    mla_decode_fwd(q,
                   kv_buffer.view(-1, 1, 1, q.shape[-1]),
                   o,
                   qo_indptr,
                   kv_indptr,
                   kv_indices,
                   kv_last_page_lens,
                   max_seqlen_qo,
                   sm_scale=sm_scale,
                   logit_cap=logit_cap)


def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    """Fake (shape-only) implementation of ``rocm_aiter_mla_decode_fwd``.

    Registered as the op's ``fake_impl``. The real op only writes into the
    preallocated ``o`` and returns ``None``, so there is no output to
    describe. This fake therefore does nothing and returns ``None``.
    """
    return None


# Register the AITER MLA decode kernel as the vLLM custom op
# ``torch.ops.vllm.rocm_aiter_mla_decode_fwd``, only when the current platform
# reports ROCm. ``o`` is listed as mutated because the implementation writes
# into that buffer and returns nothing; the fake implementation serves tracing.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_aiter_mla_decode_fwd",
        op_func=mla_decode_fwd_impl,
        fake_impl=mla_decode_fwd_fake,
        mutates_args=["o"],
        tags=[torch.Tag.needs_fixed_stride_order],
    )