topk_select.py 5.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import Optional

import torch


def _packed_varlen_coords(
    *,
    cu_seqlens: torch.Tensor,  # [B+1]
    total_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derive per-request and per-token coordinates of a packed varlen batch.

    Args:
      cu_seqlens: [B+1] cumulative sequence lengths; entry ``i`` is the start
        of request ``i`` and entry ``i + 1`` is its exclusive end.
      total_tokens: number of packed tokens ``T`` to produce coordinates for.

    Returns:
      starts: [B] int64, inclusive start offset of every request
      ends: [B] int64, exclusive end offset of every request
      lengths: [B] int64, ``ends - starts``
      req_ids: [T] int64, owning request of each packed token
      pos_in_req: [T] int64, offset of each token from its request's start

    All five tensors are empty when ``cu_seqlens`` describes no request; the
    two per-token tensors are empty when ``total_tokens <= 0``.
    """
    dev = cu_seqlens.device
    num_reqs = int(cu_seqlens.numel() - 1)

    def _no_entries() -> torch.Tensor:
        return torch.empty((0, ), device=dev, dtype=torch.long)

    if num_reqs <= 0:
        per_req = _no_entries()
        per_tok = _no_entries()
        return per_req, per_req, per_req, per_tok, per_tok

    seg_begin = cu_seqlens[:num_reqs].to(torch.long)
    seg_end = cu_seqlens[1:num_reqs + 1].to(torch.long)
    seg_len = seg_end - seg_begin

    if total_tokens <= 0:
        per_tok = _no_entries()
        return seg_begin, seg_end, seg_len, per_tok, per_tok

    flat_pos = torch.arange(total_tokens, device=dev, dtype=torch.long)
    # With right=True a token's request id is the number of segment ends that
    # are <= its packed index, so zero-length requests are skipped over.
    owner = torch.bucketize(flat_pos, seg_end, right=True)  # [T]
    offset = flat_pos - seg_begin.index_select(0, owner)  # [T]
    return seg_begin, seg_end, seg_len, owner, offset


def _topk_keep_mask_and_local_rank(
    *,
    token_scores: Optional[torch.Tensor],  # [T] float32
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    starts: torch.Tensor,  # [B] int64
    lengths: torch.Tensor,  # [B] int64
    req_ids: torch.Tensor,  # [T] int64
    pos_in_req: torch.Tensor,  # [T] int64
    max_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute keep_mask/local_rank for token-shared Top-K selection.

    Every ``must_keep`` token is kept. In addition, each request keeps its
    highest-scoring non-must-keep tokens, up to
    ``min(max(topk_budget, 0), length - must_keep_count)`` of them.

    Args:
      token_scores: [T] per-token scores; may be None only when no request
        has any budget left to spend (effective ``k_max == 0``).
      must_keep: [T] bool, tokens that are always kept.
      topk_budget: [B] number of extra tokens each request may keep.
      starts: [B] int64, packed start offset of each request.
      lengths: [B] int64, token count of each request.
      req_ids: [T] int64, owning request of each packed token.
      pos_in_req: [T] int64, position of each token inside its request.
      max_len: optional host-side upper bound on ``lengths``; when omitted it
        is read back from ``lengths`` with ``.item()``.
      topk_budget_max: optional host-side upper bound on the per-request
        budget; when omitted it is read back from the device with ``.item()``.

    Returns:
      keep_mask: [T] bool, selected tokens (includes must_keep)
      local_rank: [T] int64, rank among kept tokens within each request. For
        a token that is not kept the value is the number of kept tokens that
        precede it in its request minus one (so it can be -1).
      keep_len: [B] int32, number of kept tokens per request

    Raises:
      ValueError: if ``token_scores`` is None while some selection is needed.

    NOTE(review): the prefix-sum ranking assumes requests occupy contiguous,
    ascending ranges of the packed token axis -- confirm against callers.
    """
    device = must_keep.device
    T = int(must_keep.numel())
    B = int(topk_budget.numel())

    keep_mask = must_keep.clone()
    if T == 0 or B == 0:
        # Zeros rather than torch.empty: with B == 0 and T > 0 the ranks are
        # meaningless, but they should at least be deterministic.
        local_rank = torch.zeros((T, ), device=device, dtype=torch.long)
        keep_len = torch.zeros((B, ), device=device, dtype=torch.int32)
        return keep_mask, local_rank, keep_len

    if max_len is None:
        L_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
    else:
        L_max = int(max_len)
    L_max = max(L_max, 0)

    # Per-request budget cannot exceed the number of non-must-keep tokens.
    must_keep_counts = torch.zeros((B, ), device=device, dtype=torch.long)
    must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
    cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
    k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)

    # CPU-known bound avoids a device->host sync; clamp for safety.
    if topk_budget_max is None:
        k_max = int(k_eff.max().item()) if k_eff.numel() > 0 else 0
    else:
        k_max = int(topk_budget_max)
    k_max = min(max(k_max, 0), L_max)

    if k_max > 0:
        if token_scores is None:
            raise ValueError("token_scores must be provided when k_max > 0.")
        neg_inf = float("-inf")
        # -inf is the sentinel for "not selectable" (must_keep tokens and
        # padded columns). Lift candidate scores to at least the smallest
        # finite float32 so a candidate that is itself -inf still outranks
        # every sentinel; otherwise topk could spend budget on a sentinel
        # slot and either under-select or hit a neighbouring request's token.
        lowest_finite = torch.finfo(torch.float32).min
        masked_scores = token_scores.to(torch.float32).clamp_min(
            lowest_finite).masked_fill(must_keep, neg_inf)

        # Scatter the packed scores into a dense, -inf padded [B, L_max] grid.
        scores_flat = masked_scores.new_full((B * L_max, ), neg_inf)
        linear = req_ids * L_max + pos_in_req
        scores_flat[linear] = masked_scores
        scores = scores_flat.view(B, L_max)

        topk_pos = torch.topk(scores, k=k_max, dim=1).indices  # [B, k_max]
        # topk is taken with the shared k_max; only the first k_eff[b]
        # columns of row b are real selections.
        col_mask = torch.arange(k_max,
                                device=device).unsqueeze(0) < k_eff.unsqueeze(1)

        global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long)  # [B,k_max]
        # Masked-out columns may point past the packed range; clamp them so
        # the scatter stays in bounds (they contribute 0 anyway).
        flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
        flat_val = col_mask.reshape(-1).to(torch.int32)
        tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
        tmp.scatter_add_(0, flat_idx, flat_val)
        keep_mask |= tmp > 0

    keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
    keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))

    # Stable, order-preserving local rank using segment-local prefix sums:
    # rank = (kept tokens up to and including t) - (kept tokens before the
    # request's first token) - 1.
    keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0)  # [T]
    # starts == 0 has no preceding token; clamp the lookup index and zero the
    # result for those requests below.
    start_minus_1 = (starts - 1).clamp_min(0)
    prefix_before_all = keep_prefix.index_select(0, start_minus_1)
    prefix_before = torch.where(starts > 0, prefix_before_all,
                                torch.zeros_like(prefix_before_all))  # [B]
    prefix_before_per_token = prefix_before.index_select(0, req_ids)  # [T]
    local_rank = keep_prefix - prefix_before_per_token - 1  # [T]
    return keep_mask, local_rank, keep_len.to(torch.int32)