# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations from typing import Optional import torch def _packed_varlen_coords( *, cu_seqlens: torch.Tensor, # [B+1] total_tokens: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute packed varlen segment coordinates. Returns: starts: [B] int64, segment start offsets (inclusive) ends: [B] int64, segment end offsets (exclusive) lengths: [B] int64, segment lengths (ends - starts) req_ids: [T] int64, request id for each token in packed [0, T) pos_in_req: [T] int64, position within its request segment """ device = cu_seqlens.device B = int(cu_seqlens.numel() - 1) if B <= 0: empty = torch.empty((0, ), device=device, dtype=torch.long) t_empty = torch.empty((0, ), device=device, dtype=torch.long) return empty, empty, empty, t_empty, t_empty starts = cu_seqlens[:B].to(torch.long) ends = cu_seqlens[1:B + 1].to(torch.long) lengths = ends - starts if total_tokens <= 0: t_empty = torch.empty((0, ), device=device, dtype=torch.long) return starts, ends, lengths, t_empty, t_empty token_idx = torch.arange(total_tokens, device=device, dtype=torch.long) req_ids = torch.bucketize(token_idx, ends, right=True) # [T] start_per_token = starts.index_select(0, req_ids) pos_in_req = token_idx - start_per_token return starts, ends, lengths, req_ids, pos_in_req def _topk_keep_mask_and_local_rank( *, token_scores: Optional[torch.Tensor], # [T] float32 must_keep: torch.Tensor, # [T] bool topk_budget: torch.Tensor, # [B] int32 starts: torch.Tensor, # [B] int64 lengths: torch.Tensor, # [B] int64 req_ids: torch.Tensor, # [T] int64 pos_in_req: torch.Tensor, # [T] int64 max_len: Optional[int] = None, topk_budget_max: Optional[int] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute keep_mask/local_rank for token-shared Top-K selection. Returns: keep_mask: [T] bool, selected tokens (includes must_keep) local_rank: [T] int64, rank among kept tokens within each request keep_len: [B] int32, number of kept tokens per request """ device = must_keep.device T = int(must_keep.numel()) B = int(topk_budget.numel()) keep_mask = must_keep.clone() if T == 0 or B == 0: local_rank = torch.empty((T, ), device=device, dtype=torch.long) keep_len = torch.zeros((B, ), device=device, dtype=torch.int32) return keep_mask, local_rank, keep_len if max_len is None: L_max = int(lengths.max().item()) if lengths.numel() > 0 else 0 else: L_max = int(max_len) if L_max < 0: L_max = 0 must_keep_counts = torch.zeros((B, ), device=device, dtype=torch.long) must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long)) cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0) k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts) # CPU-known bound avoids a device->host sync; clamp for safety. if topk_budget_max is None: k_max = int(k_eff.max().item()) if k_eff.numel() > 0 else 0 else: k_max = int(topk_budget_max) if k_max < 0: k_max = 0 if k_max > L_max: k_max = L_max if k_max > 0: if token_scores is None: raise ValueError("token_scores must be provided when k_max > 0.") masked_scores = token_scores.to(torch.float32).masked_fill( must_keep, float("-inf")) scores_flat = masked_scores.new_full((B * L_max, ), float("-inf")) linear = req_ids * L_max + pos_in_req scores_flat[linear] = masked_scores scores = scores_flat.view(B, L_max) topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max] col_mask = torch.arange(k_max, device=device).unsqueeze(0) < k_eff.unsqueeze(1) global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max] flat_idx = global_sel.reshape(-1).clamp_(0, T - 1) flat_val = col_mask.reshape(-1).to(torch.int32) tmp = torch.zeros((T, ), device=device, dtype=torch.int32) tmp.scatter_add_(0, flat_idx, flat_val) keep_mask |= tmp > 0 keep_len = torch.zeros((B, ), device=device, dtype=torch.long) keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long)) # Stable, order-preserving local rank using segment-local prefix sums. keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T] start_minus_1 = (starts - 1).clamp_min(0) prefix_before_all = keep_prefix.index_select(0, start_minus_1) prefix_before = torch.where(starts > 0, prefix_before_all, torch.zeros_like(prefix_before_all)) # [B] prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T] local_rank = keep_prefix - prefix_before_per_token - 1 # [T] return keep_mask, local_rank, keep_len.to(torch.int32)