write_page_table.py 3.49 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import triton
import triton.language as tl


def scatter_to_page_table(
    add_pages: torch.Tensor,  # [L, H] int32
    new_phys_pages: torch.Tensor,  # [N]
    curr_pages: torch.Tensor,  # [L, H] int32
    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
    max_pages_per_head: int,
):
    """
    Append newly allocated physical pages into a layered page table via Triton.

    For each (layer ``l``, head ``h``) with ``n = add_pages[l, h] > 0`` and
    ``c = curr_pages[l, h]``, that pair's ``n`` physical page IDs (its slice
    of ``new_phys_pages``) are written to ``page_table[l, h, c : c + n]``.
    Logical positions at or beyond ``max_pages_per_head`` are skipped.
    Pairs with ``add_pages <= 0`` write nothing. Counts are expected to be
    non-negative, because each pair's slice offset is a plain prefix sum of
    ``add_pages``.

    Args:
        :param add_pages:
            Tensor of shape ``[L, H]`` indicating how many pages to append
            for each (layer, head). It is converted to int32.
        :param new_phys_pages:
            1D tensor of shape ``[N]`` containing physical page IDs for all
            (layer, head) pairs, concatenated in row-major (L, H) order.
            ``N`` must equal ``add_pages.sum()``. This is not checked here,
            because checking would force a device-to-host sync.
        :param curr_pages:
            Tensor of shape ``[L, H]`` with the current logical page counts
            per (layer, head) before this update. It is converted to int32.
        :param page_table:
            Tensor of shape ``[L, H, P]`` holding the logical to physical
            page mapping. It may be non-contiguous; its strides are passed
            to the kernel.
        :param max_pages_per_head:
            Maximum number of logical pages permitted per (layer, head). It
            must not exceed ``P``. The kernel skips writes beyond this bound.
    Returns:
        None. The function updates ``page_table`` in-place.
    Raises:
        ValueError: If ``add_pages`` is not 2-D, or ``curr_pages`` has a
            different shape. Also raised if ``page_table`` is not 3-D with
            leading dimensions ``[L, H]``, or if ``max_pages_per_head``
            exceeds ``page_table.shape[2]``.
    """
    L, H = add_pages.shape
    if L == 0 or H == 0:
        return

    # Validate shapes on the host. The kernel does no bounds checking, so a
    # mismatch here would otherwise become out-of-bounds device accesses.
    if curr_pages.shape != add_pages.shape:
        raise ValueError(
            f"curr_pages shape {tuple(curr_pages.shape)} does not match "
            f"add_pages shape {tuple(add_pages.shape)}"
        )
    if page_table.dim() != 3 or tuple(page_table.shape[:2]) != (L, H):
        raise ValueError(
            f"page_table must have shape [{L}, {H}, P], "
            f"got {tuple(page_table.shape)}"
        )
    if max_pages_per_head > page_table.shape[2]:
        raise ValueError(
            f"max_pages_per_head={max_pages_per_head} exceeds page_table "
            f"last dimension {page_table.shape[2]}"
        )

    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
    # The kernel reads flat_new_phys at unit-stride offsets, so it must be dense.
    phys_flat = new_phys_pages.contiguous().view(-1)

    # Exclusive prefix sum: cum_page_heads[i] is where pair i's pages start in
    # phys_flat. Allocate on the inputs' device rather than the current CUDA
    # device, so multi-GPU callers get a pointer on the right device.
    cum_page_heads = torch.zeros(L * H + 1, device=add_flat.device, dtype=torch.int32)
    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])

    stride_pl, stride_ph, stride_pp = page_table.stride()
    # One program per (layer, head) pair.
    grid = (L, H)
    _scatter_pages_kernel_lh[grid](
        add_flat,
        cum_page_heads,
        phys_flat,
        curr_flat,
        page_table,
        stride_pl,
        stride_ph,
        stride_pp,
        L=L,
        H=H,
        max_pages_per_head=max_pages_per_head,
    )


@triton.jit
def _scatter_pages_kernel_lh(
    add_pages,  # int32 [L*H], number of pages to append per (l,h)
    cum_page_heads,  # int32 [L*H + 1] as built by the caller; entry lh is the base offset in flat_new_phys for (l,h)
    flat_new_phys,  # [total_pages] physical page IDs, read at base + i (assumes unit stride; dtype presumably int32 — confirm)
    curr_pages,  # int32 [L*H], existing logical pages per (l,h)
    page_table_ptr,  # base pointer to page_table; addressed purely through the strides below
    stride_pl,  # int, stride for layer dim
    stride_ph,  # int, stride for head dim
    stride_pp,  # int, stride for page dim
    L: tl.constexpr,
    H: tl.constexpr,
    max_pages_per_head: tl.constexpr,
):
    # Scatter one (layer, head) pair's new physical page IDs into its row of
    # the page table, starting at that pair's current logical page count.
    # L, H and max_pages_per_head are tl.constexpr, so each distinct value
    # is baked into a separately compiled kernel.
    # Axis 0 of the launch grid selects the layer, axis 1 the head.
    layer_idx = tl.program_id(0)
    h = tl.program_id(1)
    # Defensive guard in case the grid is larger than (L, H).
    if layer_idx >= L or h >= H:
        return

    # Row-major flat index into the [L*H] per-pair arrays.
    lh = layer_idx * H + h
    ap = tl.load(add_pages + lh)
    # Nothing to append for this pair (zero or negative count).
    if ap <= 0:
        return

    # Start of this pair's slice in flat_new_phys, and its current page count.
    base = tl.load(cum_page_heads + lh)
    cp = tl.load(curr_pages + lh)

    # Append ap pages: logical pages [cp .. cp+ap)
    for i in tl.range(0, ap):
        phys = tl.load(flat_new_phys + base + i)
        lp = cp + i
        # Writes past max_pages_per_head are dropped. The loop still runs
        # (and still loads phys) for those iterations.
        if lp < max_pages_per_head:
            # Stride-based addressing, so a non-contiguous page_table works.
            offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
            tl.store(page_table_ptr + offset, phys)


# TODO: write reclaim kernel
# Placeholder Triton kernel: it takes no arguments and its body is empty, so
# launching it would do nothing. No launch site is visible in this section.
# NOTE(review): presumably intended as the device-side counterpart of
# ``reclaim_pages``; confirm the intended contract before implementing.
@triton.jit
def reclaim_page_kernel():
    pass


def reclaim_pages(
    batch_index: int,
    bh_seq_lens: torch.Tensor,
    bh_num_pages: torch.Tensor,
    page_table: torch.Tensor,
):
    """
    Reclaim page-table entries for one batch element (not yet implemented).

    This is currently a stub. It ignores every argument, modifies nothing,
    and returns ``None``.

    NOTE(review): judging only by the parameter names, the intent is
    presumably to release pages for ``batch_index`` based on per-(batch,
    head) sequence lengths and page counts. Confirm the contract before
    implementing.

    Args:
        :param batch_index: Index of the batch element (unused).
        :param bh_seq_lens: Per-(batch, head) sequence lengths (unused).
        :param bh_num_pages: Per-(batch, head) page counts (unused).
        :param page_table: Page table to reclaim from (unused).
    Returns:
        None.
    """
    return None