Commit 2b7160c6 authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.0

parent fa718036
import math
from typing import Optional
import torch
import triton
from triton import language as tl
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
class SnapKVCompression(BaseCompressionMethod):
    """Compression method whose key scores come only from post-RoPE attention.

    The pre-RoPE hook contributes nothing; the post-RoPE hook delegates to
    :func:`query_aware_key_scores` with a fixed observation window of 32
    queries per sequence.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Produce no pre-RoPE score (always returns ``None``)."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Score keys with the query-aware kernel pipeline.

        ``v`` and ``pre_rope_scores`` are accepted for interface
        compatibility but are not used here. ``context`` supplies the
        cumulative sequence lengths and the ``STORE_STREAM`` value that is
        forwarded to ``maybe_execute_in_stream``.
        """
        return maybe_execute_in_stream(
            query_aware_key_scores,
            q,
            k,
            context.cu_seqlens_q,
            context.cu_seqlens_k,
            w=32,
            STORE_STREAM=context.STORE_STREAM,
        )
# Pass 1 of the query-aware key scoring.
# Grid: (batch b, KV head hk, row tile rid). A "row" is one (query, query-head)
# pair: the last `win` queries of sequence b, each expanded over the
# QUERY_GROUP_SIZE query heads that share KV head hk.
# For its tile of rows the program computes scaled q.k logits (already
# multiplied by 1/ln2, i.e. in base-2 units) against every key in
# [k_beg, k_end - win), writes them transposed into LOGITS, and keeps a
# streaming max `m` and sum `S = sum(exp2(logit - m))` per row, which are
# written to out_m / out_S at the end.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
        )
        for bq in [32, 64]
        for bk in [32, 64]
        for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,  # int32 pointers (cu_q, cu_k: cumulative lengths; w_b: per-batch window)
    out_m,
    out_S,  # [B, Hk, ROWS_MAX] float32
    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
    sm_scale,  # float
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
):
    # program ids
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)  # row-tile id
    # batch segment bounds
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # NOTE(review): the start of the query segment is never loaded, so nothing
    # here guards against `win` exceeding this sequence's query count —
    # confirm the caller guarantees it.
    q_win_beg = q_end - win
    k_eff_end = k_end - win
    # Nothing to do without a window, or when no key precedes the window.
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    # rows for this (b,hk)
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    # The grid is sized for ROWS_MAX; tiles past this batch's row count exit.
    if row0 >= rows_b:
        return
    # exp(x) = exp2(x * 1/ln2)
    qk_scale = sm_scale * 1.4426950408889634
    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    # map row -> (q_idx, hq_local)
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    q_idx = q_win_beg + q_off
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
    offs_d = tl.arange(0, D)
    # The head dimension is addressed with unit stride (no stride argument).
    q_ptrs = (
        Q
        + q_idx[:, None] * STRIDE_Q_NQ
        + hq_glob[:, None] * STRIDE_Q_HQ
        + offs_d[None, :]
    )
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    # Streaming log-sum-exp state: running max (starts at -inf) and running sum.
    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
        # Out-of-range keys must not contribute to the max or the sum.
        s = tl.where(kmask[None, :], s, -float("inf"))
        # store into LOGITS[nk, hk, row] -> [BK, BQ]
        log_ptrs = (
            LOGITS
            + nk[:, None] * STRIDE_LG_NK
            + hk * STRIDE_LG_HK
            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
        )
        tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
        # log2 streaming LSE update
        cur_max = tl.max(s, 1)  # [BQ]
        n_m = tl.maximum(m, cur_max)
        # Rescale the previous sum to the new max before adding this tile.
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m
    # store m,S for these rows
    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
# Pass 2 of the query-aware key scoring.
# Grid: (batch b, KV head hk). For every key before the window the program
# turns the stored base-2 logits into softmax probabilities using the per-row
# (m, S) from pass 1, sums them over all rows, optionally average-pools the
# result, and writes it to OUT. Keys inside the window are set to +inf.
# NOTE(review): the autotune key names "HK" and "HQ" are not parameters of
# this kernel; how unknown key names are treated depends on the autotune
# wrapper in use — confirm this is intended.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64]
        for bk in [32, 64, 128]
    ],
    key=["HK", "HQ"],
    cache_results=True,
)
@triton.jit
def _scores_from_logits_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,  # [B, Hk, ROWS_MAX] f32
    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits
    OUT,  # [Nk, Hk] f32
    #
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    #
    DO_POOL: tl.constexpr,  # set True to enable in-place avg pool
    KPOOL: tl.constexpr,  # kernel size for avg pool (stride=1)
):
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    # Same skip condition as pass 1. On this path nothing is written to OUT
    # for this (b, hk), including the window region below.
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    # === scores over computed region ===
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        scores = tl.zeros([BLOCK_K], dtype=tl.float32)
        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b
            # load m, S for rows
            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
            m = tl.load(
                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                mask=rmask,
                other=-float("inf"),
            )
            S = tl.load(
                S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
            )
            # Rows with S <= 0 (including masked-out rows, loaded as 0) get
            # neutral m/S so the division below is well defined; their
            # probabilities are zeroed afterwards.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)
            # load stored logits^T: [BK, BQ]
            log_ptrs = (
                LOGITS
                + nk[:, None] * STRIDE_LG_NK
                + hk * STRIDE_LG_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
            )
            s_T = tl.load(
                log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
            )  # [BK, BQ]
            # probs^T = exp2(s_T - m) / S, sum over rows
            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
            scores += tl.sum(probs_T, 1)  # [BK]
        if DO_POOL and (KPOOL > 1):
            # Trailing average: position i averages positions j with
            # i - KPOOL < j <= i. i and j are indices local to this BLOCK_K
            # tile, so the pool never reaches into the previous tile.
            i = tl.arange(0, BLOCK_K)[:, None]
            j = tl.arange(0, BLOCK_K)[None, :]
            band = (j <= i) & ((i - j) < KPOOL)
            band = band & kmask[None, :]
            # sum within band
            sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)  # [BK]
            denom = tl.sum(band, 1).to(tl.float32)  # [BK]
            # Avoid 0/0 for positions whose band is empty.
            denom = tl.where(denom > 0, denom, 1.0)
            scores = sums / denom
        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, scores, mask=kmask)
    # Keys inside the window [k_eff_end, k_end) receive +inf.
    # NOTE(review): presumably so a top-k selection always keeps the window —
    # confirm with the consumer of OUT.
    pad_beg = k_eff_end
    pad_end = k_end
    if pad_end > pad_beg:
        for ks in tl.range(pad_beg, pad_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < pad_end
            out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
            tl.store(
                out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
            )
# Optional epilogue: in-place z-score normalisation of OUT.
# Grid: (batch b,). For sequence b the mean and variance are taken jointly
# over all HK heads and all keys in [k_beg, k_end - win); the window region
# is neither read nor modified.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,  # [Nk, Hk], float32
    cu_k,
    w_b,  # [B+1], [B] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    EPS: tl.constexpr,  # e.g., 1e-12
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    # No key precedes the window: nothing to normalise.
    if k_eff_end <= k_beg:
        return
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    # Number of (key, head) entries that enter the statistics.
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)
    # First sweep: accumulate sum and sum of squares.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)
    mean = sumv / count
    # Variance as E[x^2] - mean^2, clamped at 0 against rounding error.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)
    # Second sweep: rewrite every entry as (x - mean) / std.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1], int32
cu_seqlens_k: torch.Tensor, # [B+1], int32
w: torch.Tensor | int, # [B], int32
sm_scale: float = None, # defaults to 1/sqrt(D)
*,
accum_scores: torch.Tensor = None,
accum_blending: float = None,
normalize: bool = False,
) -> Optional[torch.Tensor]:
assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = cu_seqlens_q.numel() - 1
assert B == cu_seqlens_k.numel() - 1
G = Hq // Hk
if type(w) is int:
max_w = w
w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
else:
max_w = int(w.max().item())
assert w.numel() == B
ROWS_MAX = max_w * G
if ROWS_MAX == 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
out = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
# strides
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
_lse_and_store_logits_kernel[grid](
q,
k,
cu_seqlens_q,
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=Hq // Hk,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX,
)
_scores_from_logits_kernel[(B, Hk)](
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=Hq // Hk,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=True,
KPOOL=5,
)
if normalize:
_zscore_per_batch_epilogue[(B,)](
out,
cu_seqlens_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
if accum_scores is not None:
if accum_blending is not None:
accum_scores.mul_(accum_blending)
accum_scores.add_(out)
return accum_scores
else:
return out
# Plain integer constant, also exported below under a Triton-facing alias.
# NOTE(review): the name suggests batch slot 0 is set aside for internal use —
# confirm against the code that consumes it.
RESERVED_BATCH = 0
# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
TRITON_RESERVED_BATCH = RESERVED_BATCH
import os
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional
from transformers import AutoConfig
class AttentionBackend(Enum):
    """Attention implementations the engine can be configured to use."""

    # Values are the integers 1 and 2, in declaration order.
    FLASH_ATTENTION = 1
    COMPACTOR_TRITON = 2
@dataclass
class LLMConfig:
    """Configuration for the :class:`LLM` engine.

    Parameters
    ----------
    model : str
        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
        a local model name that can be resolved by
        :func:`transformers.AutoConfig.from_pretrained`.
    path : str, optional
        Local directory containing the model weights. If ``None``, the engine
        will attempt to resolve a local snapshot for ``model`` using
        :func:`huggingface_hub.snapshot_download`.
    nccl_port : int, optional, default 1218
        Port number. NOTE(review): presumably the rendezvous port for the
        tensor-parallel process group; confirm where it is consumed.
    max_num_seqs : int, default 256
        Upper bound on the number of concurrent batches that the scheduler and
        KV-cache manager are allowed to handle. This affects the size of the
        page table and some internal buffers.
    max_model_len : int, default 40960
        Maximum context length (in tokens) that the engine will allocate KV cache
        and CUDA graphs for. During initialization this value is clamped to
        ``hf_config.max_position_embeddings`` for the chosen model.
    gpu_memory_utilization : float, default 0.9
        Fraction of the total GPU memory that may be used for KV cache and model
        activations. Values should be in ``(0, 1]``. If this budget is too small,
        the KV-cache manager may raise an error at warmup time due
        to insufficient memory.
    tensor_parallel_size : int, default 1
        Number of tensor-parallel workers to shard the model
        across. Must be between 1 and 8, and must evenly divide the model's
        number of key/value heads.
    enforce_eager : bool, default False
        If ``True``, disable CUDA graph capture and always run the model in
        eager mode during decoding. This reduces throughput. When ``False``,
        the engine will capture and reuse CUDA graphs for supported
        batch sizes and sequence lengths.
    hf_config : transformers.AutoConfig, optional
        Pre-loaded Hugging Face configuration for the model. If ``None``,
        it will then be populated automatically based on ``model``.
    eos : int, default -1
        Primary stop token id (warmup / single-id paths). If ``-1``, the
        :class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
        tokenizer.
    eos_token_ids : list of int, optional
        All token ids that terminate generation (e.g. HF tokenizers may expose
        ``eos_token_id`` as a list for chat models). If ``None``, inferred in
        :class:`LLM` from the tokenizer and model type.
    kvcache_page_size : int, default 128
        Number of tokens stored in a single KV-cache page. Smaller pages improve
        allocation flexibility but increase page-table overhead; larger pages
        reduce overhead but have coarser granularity.
    leverage_sketch_size : int, default 48
        Sketch dimension used by the Compactor leverage-score estimator.
    attention_backend : AttentionBackend, default AttentionBackend.COMPACTOR_TRITON
        Attention implementation to use. ``COMPACTOR_TRITON`` selects the custom
        Triton kernels used by Compactor; ``FLASH_ATTENTION`` selects the
        FlashAttention3 varlen backend. The COMPACTOR_TRITON tends to be faster
        for longer sequence lengths, while FA3 is faster at shorter lengths.
    show_progress_bar : bool, default True
        NOTE(review): presumably toggles progress reporting during
        generation; confirm where it is consumed.

    Raises
    ------
    NotADirectoryError
        If ``path`` is given but is not an existing directory.
    ValueError
        If ``tensor_parallel_size`` is outside ``[1, 8]``.
    """

    model: str
    path: Optional[str] = None
    nccl_port: Optional[int] = 1218
    max_num_seqs: int = 256
    max_model_len: int = 40960
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: AutoConfig | None = None
    eos: int = -1
    eos_token_ids: Optional[List[int]] = None
    kvcache_page_size: int = 128
    leverage_sketch_size: int = 48
    attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
    show_progress_bar: bool = True

    def __post_init__(self):
        """Validate the fields and resolve the Hugging Face config."""
        if self.path is not None and not os.path.isdir(self.path):
            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
        # Raise ValueError directly. An assert placed inside this branch can
        # only ever fail, which would mask the intended exception type.
        if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
        if self.hf_config is None:
            self.hf_config = AutoConfig.from_pretrained(self.model)
        # Never plan for a longer context than the model supports.
        self.max_model_len = min(
            self.max_model_len, self.hf_config.max_position_embeddings
        )
from dataclasses import dataclass
@dataclass
class SamplingParams:
    """Per-request decoding options.

    ``temperature`` must not be negative; ``max_new_tokens`` is stored as
    given and not validated here.
    """

    temperature: float = 1.0
    max_new_tokens: int = 256

    def __post_init__(self):
        """Reject a negative temperature with ``ValueError``."""
        is_negative = self.temperature < 0
        if is_negative:
            raise ValueError("Temperature cannot be negative")
import atexit
import inspect
import logging
from typing import Any, List, Optional, Union
import torch.multiprocessing as mp
from compactor_vllm.compression.compression_config import (
BatchCompressionParams,
SequenceCompressionParams,
)
from compactor_vllm.config.engine_config import LLMConfig
from compactor_vllm.config.sampling_params import SamplingParams
from compactor_vllm.core.model_runner import ModelRunner
from compactor_vllm.models import MODEL_REGISTRY
from compactor_vllm.utils.sequence import Sequence
from transformers import AutoTokenizer
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# A prompt is either raw text or a list of already-tokenized input ids.
PromptLike = Union[str, List[int]]
def _infer_stop_token_ids(tokenizer, hf_config) -> list[int]:
    """
    Collect the token ids that should terminate generation.

    The tokenizer's ``eos_token_id`` is taken first; it may be a single id or a
    list/tuple of ids (as some HF chat tokenizers expose it), so callers must
    treat the result as a set of ids rather than one ``int``.

    For Qwen model types (``qwen2``, ``qwen3`` and their ``_moe`` variants) any
    added or additional special token whose text contains ``im_end`` (e.g.
    ``<|im_end|>``, the assistant turn boundary) is appended as well. Only that
    exact substring is matched, to avoid tagging unrelated tokens. Negative ids,
    non-integer ids, the unknown-token id and duplicates of ids found in this
    second step are skipped.

    Raises ``ValueError`` when no id could be determined.
    """
    eos_field = tokenizer.eos_token_id
    if isinstance(eos_field, (list, tuple)):
        stop_ids: list[int] = [int(item) for item in eos_field]
    elif eos_field is None:
        stop_ids = []
    else:
        stop_ids = [int(eos_field)]

    unknown_id = getattr(tokenizer, "unk_token_id", None)

    def _register(candidate) -> None:
        # Accept only non-negative ints that are neither the unknown token
        # nor already collected.
        if not (isinstance(candidate, int) and candidate >= 0):
            return
        if unknown_id is not None and candidate == unknown_id:
            return
        if candidate in stop_ids:
            return
        stop_ids.append(candidate)

    qwen_types = ("qwen2", "qwen3", "qwen2_moe", "qwen3_moe")
    if getattr(hf_config, "model_type", None) in qwen_types:
        encoder = getattr(tokenizer, "added_tokens_encoder", None)
        if isinstance(encoder, dict):
            for token_text, token_id in encoder.items():
                if isinstance(token_text, str) and "im_end" in token_text:
                    _register(int(token_id))

        special_tokens = getattr(tokenizer, "additional_special_tokens", []) or []
        for token_text in special_tokens:
            if isinstance(token_text, str) and "im_end" in token_text:
                try:
                    token_id = tokenizer.convert_tokens_to_ids(token_text)
                except (TypeError, ValueError, KeyError):
                    continue
                _register(token_id)

    if stop_ids:
        return stop_ids
    raise ValueError(
        "Could not infer stop token ids from the tokenizer; set "
        "LLMConfig(eos_token_ids=[...]) explicitly."
    )
def _merge_apply_chat_template_kwargs(
    tokenizer,
    user_kwargs: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """
    Combine caller kwargs with engine defaults for ``apply_chat_template``.

    A default is added only when the tokenizer's ``apply_chat_template``
    signature names the parameter explicitly and the caller did not set it:

    * ``add_generation_prompt=True`` so the first generated token continues
      the assistant turn instead of repeating template fragments;
    * ``enable_thinking=False`` to skip the Qwen3 reasoning channel.

    The caller's dict is copied, never mutated. If the signature cannot be
    inspected, the copy is returned unchanged.
    """
    merged = {} if not user_kwargs else dict(user_kwargs)
    try:
        accepted = inspect.signature(tokenizer.apply_chat_template).parameters
    except (TypeError, ValueError):
        return merged
    engine_defaults = (
        ("add_generation_prompt", True),
        ("enable_thinking", False),
    )
    for option, value in engine_defaults:
        if option in accepted and option not in merged:
            merged[option] = value
    return merged
def _runner_entry(config: LLMConfig, rank: int, evt):
    """Process target for a worker rank: build a ModelRunner and run its loop.

    Any exception is logged with its traceback instead of propagating, and the
    runner's ``exit()`` is always called when construction succeeded.
    """
    worker = None
    try:
        worker = ModelRunner(config, rank, evt)
        worker.loop()
    except Exception as exc:
        logging.exception(f"Rank {rank}: {exc!r}")
    finally:
        # `worker` is still None when the constructor itself raised.
        if worker is not None:
            worker.exit()
class LLMEngine:
    """High-level engine coordinating model runners and scheduling"""

    def __init__(self, config: LLMConfig):
        """Resolve model files and stop tokens, then start the model runners.

        Rank 0 runs in this process; ranks ``1 .. tensor_parallel_size - 1``
        are spawned as daemon processes. ``config`` is updated in place
        (``path``, ``eos``, ``eos_token_ids``).

        Raises:
            ValueError: If the model type is not in ``MODEL_REGISTRY``.
        """
        self.config = config
        if self.config.hf_config.model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model {self.config.model}")
        if config.path is None:
            from huggingface_hub import snapshot_download

            # Only an already-downloaded snapshot is used (no network access).
            self.config.path = snapshot_download(
                repo_id=config.model, local_files_only=True
            )
            logger.info(
                "Using %s snapshot @ %s", self.config.model, self.config.path
            )
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model, use_fast=True)

        # Normalise the stop-token configuration: `eos_token_ids` becomes a
        # sorted, de-duplicated list of ints and `eos` one of its members.
        if self.config.eos_token_ids is None:
            if self.config.eos != -1:
                self.config.eos_token_ids = [int(self.config.eos)]
            else:
                self.config.eos_token_ids = _infer_stop_token_ids(
                    self.tokenizer, self.config.hf_config
                )
        else:
            self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
        self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
        if self.config.eos == -1:
            self.config.eos = int(self.config.eos_token_ids[0])
        else:
            self.config.eos = int(self.config.eos)
            if self.config.eos not in self.config.eos_token_ids:
                self.config.eos_token_ids = sorted(
                    self.config.eos_token_ids + [self.config.eos]
                )

        self.ps = []
        world_size = int(self.config.tensor_parallel_size)
        self.events = []
        if world_size > 1:
            ctx = mp.get_context("spawn")
            for r in range(1, world_size):
                event = ctx.Event()
                p = ctx.Process(
                    target=_runner_entry,
                    args=(self.config, r, event),
                    daemon=True,
                )
                p.start()
                self.ps.append(p)
                self.events.append(event)
        self.master_model_runner = ModelRunner(
            self.config, rank=0, peer_events=self.events
        )
        atexit.register(self.exit)

    def exit(self):
        """Shut down the master runner and worker processes (idempotent)."""
        if getattr(self, "_exited", False):
            return
        self._exited = True
        # getattr: exit() may run via atexit after a partially failed __init__.
        runner = getattr(self, "master_model_runner", None)
        if runner is not None:
            try:
                runner.exit()
            except Exception:
                # Best effort: still tear down the worker processes below.
                logger.exception("Failed to exit master ModelRunner cleanly")
        for p in self.ps:
            if p.is_alive():
                p.terminate()
            p.join(timeout=1.0)
        if hasattr(self, "events"):
            self.events.clear()

    def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
        """
        Turn a raw prompt into token IDs.

        Strings are tokenized with ``tokenizer_kwargs``; anything else is
        treated as pre-tokenized ids and returned as a new list.
        """
        if isinstance(prompt, str):
            return self.tokenizer(prompt, **tokenizer_kwargs)["input_ids"]
        return list(prompt)

    def detokenize_prompt(
        self, sequences: List[Sequence], **detokenizer_kwargs
    ) -> List[str]:
        """
        Turn completed Sequences into strings.

        ``skip_special_tokens=True`` is the default and can be overridden
        through ``detokenizer_kwargs``.
        """
        decode_kwargs: dict[str, Any] = {
            "skip_special_tokens": True,
            **detokenizer_kwargs,
        }
        return self.tokenizer.batch_decode(
            [s.completion_token_ids for s in sequences], **decode_kwargs
        )

    def _build_sequences(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        per_sequence_compression_params: Optional[
            SequenceCompressionParams | List[SequenceCompressionParams]
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
    ) -> List[Sequence]:
        """
        Build Sequence objects from prompts, sampling params, and optional
        per-sequence compression parameters.

        A single ``SamplingParams`` / ``SequenceCompressionParams`` is applied
        to every prompt. When a prompt is no longer than its protected first
        and last regions combined, that sequence gets a copy of its
        compression params with ``compression_ratio`` forced to ``1.0``.
        """
        import copy

        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        if not isinstance(prompts, list):
            prompts = [prompts]

        if isinstance(sampling_params, SamplingParams):
            sampling_params_list: List[SamplingParams] = [sampling_params] * len(
                prompts
            )
        else:
            sampling_params_list = sampling_params
            assert len(sampling_params_list) == len(prompts), (
                "sampling_params list must match prompts length"
            )

        if per_sequence_compression_params is None:
            compression_params_list: List[SequenceCompressionParams] = [
                SequenceCompressionParams(1.0) for _ in prompts
            ]
        elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
            compression_params_list = [per_sequence_compression_params] * len(prompts)
        else:
            # list-like
            assert len(per_sequence_compression_params) == len(prompts), (
                "per_sequence_compression_params list must match prompts length"
            )
            compression_params_list = list(per_sequence_compression_params)

        seqs: List[Sequence] = []
        for prompt, sparams, cparams in zip(
            prompts, sampling_params_list, compression_params_list
        ):
            token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
            protected = (
                cparams.protected_first_tokens + cparams.protected_last_tokens
            )
            if protected >= len(token_ids):
                # Override on a copy: the same params object may be shared by
                # every prompt in the batch (and is owned by the caller), so
                # mutating it would switch compression off for all of them.
                cparams = copy.copy(cparams)
                cparams.compression_ratio = 1.0
            seqs.append(
                Sequence(
                    prompt_token_ids=token_ids,
                    sampling_params=sparams,
                    compression_params=cparams,
                )
            )
        return seqs

    def generate(
        self,
        prompts: List[PromptLike] | PromptLike,
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        *,
        per_sequence_compression_params: Optional[
            Union[List[SequenceCompressionParams], SequenceCompressionParams]
        ] = None,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Accept prompts and return the decoded completions.

        Args:
            :param prompts:
                Single prompt or list of prompts, each either a raw text prompt,
                or pre-tokenized input IDs.
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``prompts``.
            :param batch_compression_params:
                Compression settings for this batch.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                prompt in the batch. If ``None``, every sequence uses
                ``SequenceCompressionParams(1.0)``.
            :param tokenizer_kwargs:
                Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
                string prompts.
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to return sequence objects or not
        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One decoded completion per input prompt; when
                ``return_sequences`` is True, also the Sequence objects with
                `completion_token_ids` filled in after generation.
        """
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        seqs = self._build_sequences(
            prompts,
            sampling_params=sampling_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self.master_model_runner.generate(seqs, batch_compression_params)
        output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
        if return_sequences:
            return output_strings, seqs
        return output_strings

    def generate_chat(
        self,
        messages_batch: List[List[dict]],
        sampling_params: SamplingParams | List[SamplingParams],
        batch_compression_params: BatchCompressionParams,
        per_sequence_compression_params: Union[
            SequenceCompressionParams, List[SequenceCompressionParams]
        ],
        *,
        tokenizer_kwargs: Optional[dict[str, Any]] = None,
        detokenizer_kwargs: Optional[dict[str, Any]] = None,
        return_sequences: bool = False,
    ) -> List[str] | tuple[List[str], List[Sequence]]:
        """
        Convenience API for chat-style prompts using HF `apply_chat_template`.

        Args:
            :param messages_batch:
                List of conversations, where each conversation is a list of
                message dicts like:
                {"role": "system" | "user" | "assistant", "content": str}
            :param sampling_params:
                A single SamplingParams for all prompts in this batch or a list of
                SamplingParams with the same length as ``messages_batch``.
            :param batch_compression_params:
                Batch Level compression settings. Can set compression_method.
            :param per_sequence_compression_params:
                Per-sequence compression parameters, including the compression
                ratio to be applied and the size of the protected regions of the
                sequence (how many start tokens and end tokens to keep uncompressed).
                If a SequenceCompressionParams instance, the same params will be
                applied to all sequences in this batch; if a list is provided,
                each SequenceCompressionParams will be attached to the corresponding
                conversation in the batch.
            :param tokenizer_kwargs:
                Passed through to `tokenizer.apply_chat_template`, after the
                engine defaults from `_merge_apply_chat_template_kwargs` have
                been merged in.
            :param detokenizer_kwargs:
                Passed through to `tokenizer.batch_decode`.
            :param return_sequences:
                Whether to return sequence objects or not
        Returns:
            :return List[str] or tuple[List[str], List[Sequence]]:
                One string per conversation.
        """
        prompts_token_ids: List[List[int]] = []
        tokenizer_kwargs = _merge_apply_chat_template_kwargs(
            self.tokenizer, tokenizer_kwargs
        )
        detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
        for messages in messages_batch:
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                **tokenizer_kwargs,
            )
            # Convert array/tensor results to a plain Python list.
            if hasattr(input_ids, "tolist"):
                input_ids = input_ids.tolist()
            prompts_token_ids.append(input_ids)
        # The prompts are already token ids, so `tokenize_prompt` ignores the
        # tokenizer kwargs forwarded here.
        return self.generate(
            prompts_token_ids,
            sampling_params=sampling_params,
            batch_compression_params=batch_compression_params,
            per_sequence_compression_params=per_sequence_compression_params,
            tokenizer_kwargs=tokenizer_kwargs,
            detokenizer_kwargs=detokenizer_kwargs,
            return_sequences=return_sequences,
        )

    def generate_from_sequences(
        self,
        seqs: List[Sequence],
        batch_compression_params: BatchCompressionParams,
    ) -> List[Sequence]:
        """
        Run generation on pre-built Sequence objects.

        Args:
            :param seqs:
                List of Sequence instances
            :param batch_compression_params:
                Compression settings.
        Returns:
            :return List[Sequence]:
                Same list, mutated in-place with completions.
        """
        self.master_model_runner.generate(seqs, batch_compression_params)
        return seqs
import logging
from typing import Iterable, List, Optional
import torch
import torch.distributed as dist
from compactor_vllm.config.engine_config import LLMConfig
from compactor_vllm.kv_cache.page_table import KVAllocationStatus, PagedKVCache
from torch import nn
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class KVCacheManager:
def __init__(self, rank: int, config: LLMConfig):
super().__init__()
hf_config = config.hf_config
self.rank = rank
self.gpu_frac = config.gpu_memory_utilization
self.page_size = config.kvcache_page_size
self.world_size = config.tensor_parallel_size
self.max_num_batches = config.max_num_seqs
self.max_model_len = config.max_model_len
self.num_layers = hf_config.num_hidden_layers
self.model_dtype = hf_config.torch_dtype
self.head_dim = getattr(hf_config, "head_dim", None)
self.max_pages_per_batch = (
self.max_model_len + self.page_size - 1
) // self.page_size
self.num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
assert hf_config.num_key_value_heads % dist.get_world_size() == 0, (
"world size needs to divide num_kv_heads"
)
self.num_pages = None
self.paged_cache: Optional[PagedKVCache] = None
self.max_batched_tokens = None
self.seq_id_to_batch = {}
def allocate_sequences(
self, seq_ids: List[int], max_positions: List[int]
) -> (bool, Optional[torch.Tensor]):
batch_mapping = []
for seq_id, len_to_alloc in zip(seq_ids, max_positions):
if seq_id not in self.seq_id_to_batch:
batch_id = self.paged_cache.new_batch()
if batch_id is None:
logger.warning("Failed to allocate batch!")
return False, None
self.seq_id_to_batch[seq_id] = int(batch_id)
batch_mapping.append(self.seq_id_to_batch[seq_id])
if (
alloc_status := self.paged_cache.reserve_tokens(
self.seq_id_to_batch[seq_id], len_to_alloc
)
) != KVAllocationStatus.SUCCESS:
logger.warning(f"Failed to allocate pages ({alloc_status})!")
return False, None
batch_mapping = torch.as_tensor(batch_mapping, dtype=torch.int32, device="cuda")
return True, batch_mapping
def free_sequences(self, seq_ids: Iterable[int]):
for seq_id in seq_ids:
global_batch_id = self.seq_id_to_batch.pop(seq_id, None)
self.paged_cache.free_batch(global_batch_id)
def init_cache(self, model: nn.Module):
self.num_pages = self.get_num_pages(self.gpu_frac, self.max_pages_per_batch)
self.paged_cache = PagedKVCache(
num_layers=self.num_layers,
H_kv=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
num_pages=int(self.num_pages),
max_num_batches=self.max_num_batches,
device=f"cuda:{self.rank}",
dtype=self.model_dtype,
max_logical_pages_per_head=int(self.max_pages_per_batch),
)
self._assign_cache_to_layers(model)
def _assign_cache_to_layers(self, model) -> None:
for layer_index, layer in enumerate(model.model.layers):
attn = layer.self_attn.attn
k, v, pt, bh = self.paged_cache.layer_slices(layer_index)
attn.k_cache = k
attn.v_cache = v
attn.page_table = pt
attn.bh_seq_lens = bh
attn.page_size = self.page_size
def get_num_pages(self, frac: float, n_logical_pages_max: int):
free, total = torch.cuda.mem_get_info()
used = total - free
stats = torch.cuda.memory_stats()
peak = int(stats["allocated_bytes.all.peak"])
current = int(stats["allocated_bytes.all.current"])
bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
if bytes_for_kv_budget <= 0:
raise RuntimeError(
f"Insufficient memory for KV cache."
f"Try increasing gpu_memory_utilization (currently {frac:.2f})."
)
# page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
int32_sz = torch.empty((), dtype=torch.int32).element_size() # 4
page_table_bytes_per_layer = (
self.max_num_batches
* self.num_kv_heads
* n_logical_pages_max
* int32_sz # page_table
+ self.max_num_batches * self.num_kv_heads * int32_sz
)
total_page_table_bytes = self.num_layers * page_table_bytes_per_layer
kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
if kv_bytes_net <= 0:
raise RuntimeError(
"page-table footprint exceeds KV cache budget. "
f"reduce max_num_seqs ({self.max_num_batches}) "
f"or increase kv_cache_mem_fraction (currently {frac:.2f})."
)
dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
bytes_per_page_across_layers = self.num_layers * (
2 * self.page_size * self.head_dim * dtype_sz
)
return max(1, kv_bytes_net // bytes_per_page_across_layers)
def estimate_max_batched_tokens(
    self,
    warmup_tokens: int,
    bytes_used_before_warmup: int,
    bytes_peak_after_warmup: int,
) -> int:
    """
    Estimate how many tokens one forward pass may carry without OOM.

    Two limits are computed and the smaller one is stored in
    ``self.max_batched_tokens`` and returned:

    * an activation limit, from the per-token memory growth observed during
      warmup and the memory still available under ``self.gpu_frac``;
    * a cache limit, ``num_pages * page_size // num_kv_heads``.

    Both are rounded down to a multiple of ``page_size``.
    """
    assert warmup_tokens > 0, "warmup_tokens must be > 0"

    # Per-token activation cost, rounded up and never below one byte.
    activation_bytes = int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
    if activation_bytes < 0:
        activation_bytes = 0
    per_token = (activation_bytes + warmup_tokens - 1) // warmup_tokens
    if per_token < 1:
        per_token = 1

    free_bytes, total_bytes = torch.cuda.mem_get_info()
    target_bytes = int(total_bytes * self.gpu_frac)
    in_use_bytes = int(total_bytes - free_bytes)

    # Keep back the gap between the allocator's peak and current usage.
    allocator = torch.cuda.memory_stats()
    peak_alloc = int(allocator.get("allocated_bytes.all.peak", 0))
    live_alloc = int(allocator.get("allocated_bytes.all.current", 0))
    headroom = max(0, peak_alloc - live_alloc)

    spare_bytes = max(0, target_bytes - in_use_bytes - headroom)
    activation_budget = int(spare_bytes * 0.95)

    page = self.page_size
    by_activation = activation_budget // per_token
    by_cache = (self.num_pages * page) // self.num_kv_heads
    # Round both limits down to a whole number of pages.
    by_activation -= by_activation % page
    by_cache -= by_cache % page

    self.max_batched_tokens = min(by_cache, by_activation)
    return self.max_batched_tokens
@property
def num_free_batches(self) -> int:
    """Number of batch slots the paged cache currently lists as free."""
    unassigned_slots = self.paged_cache.free_batches
    return len(unassigned_slots)
@property
def num_free_pages(self) -> int:
    """Free-page count of the most constrained layer (minimum over layers)."""
    return min(map(len, self.paged_cache.free_pages))
def reclaim_pages(
self,
seq_ids_to_reclaim: Iterable[int],
future_reserved_buffer: List[int] | torch.Tensor,
) -> int:
approximate_bytes_freed = 0
for i, seq_id in enumerate(seq_ids_to_reclaim):
batch_idx = self.seq_id_to_batch[seq_id]
approximate_bytes_freed += self.paged_cache.reclaim_pages(
batch_idx, future_reserved_buffer[i]
)
return approximate_bytes_freed
import atexit
import logging
import inspect
from typing import List, Optional
import torch
import torch.distributed as dist
from compactor_vllm.attention.sparse_decode_kernel import num_splits_heuristic
from compactor_vllm.compression.compression_config import BatchCompressionParams
from compactor_vllm.config.constants import RESERVED_BATCH
from compactor_vllm.config.engine_config import AttentionBackend, LLMConfig
from compactor_vllm.core.memory_manager import KVCacheManager
from compactor_vllm.core.scheduler import Scheduler
from compactor_vllm.layers.sampler import Sampler
from compactor_vllm.models import MODEL_REGISTRY
from compactor_vllm.utils.arguments import (
DecodeBatchArguments,
DecodeBatchOutput,
PackedTensorArguments,
PrefillBatchArguments,
)
from compactor_vllm.utils.context import CompressionContext, reset_context, set_context
from compactor_vllm.utils.sequence import Sequence
from torch.multiprocessing import Event
from tqdm import tqdm
logger = logging.getLogger(__name__)
class ModelRunner:
"""Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
def __init__(
    self,
    config: LLMConfig,
    rank: int,
    batch_ready: Optional[Event] = None,
    peer_events: Optional[List[Event]] = None,
):
    """
    Initialise one tensor-parallel rank.

    In order: joins the NCCL process group, loads the model, runs a
    flash-attention warmup to measure activation memory, builds the KV
    cache, estimates the batched-token budget, optionally warms up the
    Compactor Triton backend, and (unless ``enforce_eager``) captures
    decode CUDA graphs.

    Args:
        :param config:
            Engine configuration; ``eos_token_ids`` must be non-empty.
        :param rank:
            Rank of this process; also used as the CUDA device index.
        :param batch_ready:
            Event that :meth:`loop` waits on before processing batches.
        :param peer_events:
            Events the master sets in :meth:`generate` and clears in
            :meth:`maybe_release_peers`; ``None`` means no peers.
    """
    self.rank = rank
    self.config = config
    _dev = torch.device(f"cuda:{rank}")
    assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
        "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
    )
    self._stop_token_ids = torch.tensor(
        config.eos_token_ids, dtype=torch.int64, device=_dev
    )
    hf_config = config.hf_config
    self.enforce_eager = config.enforce_eager
    self.world_size = config.tensor_parallel_size
    self.leverage_sketch_size = config.leverage_sketch_size
    self.show_progress_bar = config.show_progress_bar
    self.max_num_batches = config.max_num_seqs
    self.max_model_len = config.max_model_len
    self.num_layers = hf_config.num_hidden_layers
    self.model_dtype = hf_config.torch_dtype
    self.head_dim = getattr(hf_config, "head_dim", None)

    # Only pass device_id when this torch version's signature accepts it.
    init_kwargs = {}
    if "device_id" in inspect.signature(dist.init_process_group).parameters:
        init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        "nccl",
        f"tcp://localhost:{config.nccl_port}",
        world_size=self.world_size,
        rank=rank,
        **init_kwargs,
    )
    torch.cuda.set_device(rank)

    # Build the model with the checkpoint dtype on the GPU; the previous
    # defaults are restored once the cache exists.
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(hf_config.torch_dtype)
    torch.set_default_device("cuda")
    model_type = hf_config.model_type
    self.model = MODEL_REGISTRY[model_type](hf_config)
    self.model.load_model(
        config.path, use_tqdm=self.is_master and self.show_progress_bar
    )
    self.sampler = Sampler()

    # Bracket the warmup with allocator readings; the difference feeds
    # estimate_max_batched_tokens below.
    pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
    self.warmup(
        num_warmup_tokens=self.max_model_len,
        attention_backend=AttentionBackend.FLASH_ATTENTION,
    )
    post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

    self.kv_manager = KVCacheManager(rank, config)
    self.kv_manager.init_cache(self.model)
    self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
    torch.set_default_device("cpu")
    torch.set_default_dtype(default_dtype)

    self.batch_ready = batch_ready
    self.peer_events = peer_events if peer_events is not None else []
    self.captured_graphs = {}
    self.min_captured_len = {}
    self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
        self.max_model_len, pre_warmup_mem, post_warmup_peak
    )
    if self.is_master:
        logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
    if self.config.attention_backend == AttentionBackend.COMPACTOR_TRITON:
        self.warmup(
            num_warmup_tokens=self.max_model_len,
            attention_backend=AttentionBackend.COMPACTOR_TRITON,
        )
    if not self.enforce_eager:
        # get_cuda_graph() needs a captured batch size >= the live decode
        # batch. Powers of two alone stop short of max_num_batches when it
        # is not itself a power of two, so capture that size as well.
        graph_batch_sizes = [
            1 << i for i in range(self.max_num_batches.bit_length())
        ]
        if graph_batch_sizes and graph_batch_sizes[-1] < self.max_num_batches:
            graph_batch_sizes.append(self.max_num_batches)
        show_bar = self.is_master and self.show_progress_bar
        batch_size_iter = (
            tqdm(graph_batch_sizes, desc="Capturing CUDA Graphs")
            if show_bar
            else graph_batch_sizes
        )
        for graph_batch_size in batch_size_iter:
            for seq_len in [1024, 4096, 8192, 16384]:
                self.capture_cudagraph(graph_batch_size, seq_len)
    self.packed_args = PackedTensorArguments(
        rank=self.rank,
        max_batched_tokens=self.max_batched_tokens,
        config=self.config,
    )
    atexit.register(self.exit)
@torch.inference_mode()
def warmup(self, num_warmup_tokens: int, attention_backend: AttentionBackend):
    """
    Run two dummy prefill passes of ``num_warmup_tokens`` tokens.

    The input is a single sequence made of ``self.config.eos`` repeated.
    Peak-memory statistics are reset before each pass so the caller can
    read the peak of one pass afterwards. For the Compactor Triton backend
    a temporary sequence with id ``-1`` is allocated in the KV cache, its
    per-head lengths are zeroed after every pass, and it is freed at the
    end.
    """
    uses_paged_cache = attention_backend == AttentionBackend.COMPACTOR_TRITON
    if self.rank == 0:
        backend_name = "Compactor Triton" if uses_paged_cache else "Flash"
        logger.info(f"Warming up with {backend_name} Attention Backend")

    device = torch.device(f"cuda:{self.rank}")
    input_ids = torch.tensor(
        [self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
    )
    positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
    # One sequence spanning the whole batch: boundaries are [0, N].
    seq_bounds = [0, num_warmup_tokens]
    cu_seqlens_q = torch.tensor(seq_bounds, device=device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor(seq_bounds, device=device, dtype=torch.int32)

    batch_mapping = None
    if uses_paged_cache:
        success, batch_mapping = self.kv_manager.allocate_sequences(
            [-1], [num_warmup_tokens]
        )
        assert success

    set_context(
        is_prefill=True,
        do_compression=False,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=num_warmup_tokens,
        max_seqlen_k=num_warmup_tokens,
        batch_mapping=batch_mapping,
        attention_backend=attention_backend,
    )
    for _ in range(2):
        torch.cuda.reset_peak_memory_stats()
        self.model.compute_logits(self.model(input_ids, positions))
        dist.barrier()
        if uses_paged_cache:
            # Rewind the warmup slot's lengths so the next pass starts empty.
            self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
                1, batch_mapping.to(torch.long), 0
            )
    reset_context()
    if uses_paged_cache:
        self.kv_manager.free_sequences([-1])
def exit(self):
    """
    Release captured graphs and tear down the process group, once.

    Later calls return immediately. The process group is destroyed even if
    clearing the graph registry raises.
    """
    already_exited = getattr(self, "_exited", False)
    if already_exited:
        return
    self._exited = True
    try:
        # The attribute is absent if __init__ failed before creating it.
        if hasattr(self, "captured_graphs"):
            self.captured_graphs.clear()
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
def loop(self):
    """
    Serve batches forever on a peer rank.

    Waits on ``self.batch_ready`` with a one-second timeout and runs
    :meth:`_process_batches_peer` whenever the wait reports success.
    """
    while True:
        signalled = self.batch_ready.wait(1.0)
        if not signalled:
            continue
        self._process_batches_peer()
@torch.inference_mode()
def run_prefill(
    self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
):
    """
    Run one prefill forward pass and return the logits.

    Installs the attention/compression context for the batch, calls the
    model followed by ``compute_logits``, and resets the context.

    Args:
        :param prefill_args:
            Packed prefill inputs and compression settings; must describe
            at least one sequence and one token.
        :param batch_mapping:
            Cache batch slots for the sequences, used to index
            ``bh_seq_lens`` along dim 1.
    """
    assert prefill_args.B > 0 and prefill_args.N > 0

    # Longest per-(layer, head) length already stored for these slots.
    stored_lens = self.kv_manager.paged_cache.bh_seq_lens
    max_bh_len = stored_lens.index_select(1, index=batch_mapping).max().item()

    compression_context = CompressionContext(
        compression_method=prefill_args.compression_method,
        compression_chunk_size=prefill_args.compression_chunk_size,
        batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
        max_tokens_to_retain=prefill_args.max_tokens_to_retain,
        context_lens=prefill_args.context_lens.tolist(),
        PHI=prefill_args.PHI,
        sketch_dimension=self.leverage_sketch_size,
        protected_first_tokens=prefill_args.protected_first,
        protected_last_tokens=prefill_args.protected_last,
        compression_ratio=prefill_args.compression_ratio,
    )
    set_context(
        is_prefill=True,
        do_compression=prefill_args.do_compression,
        cu_seqlens_q=prefill_args.cu_seqlens_q,
        cu_seqlens_k=prefill_args.cu_seqlens_k,
        max_seqlen_q=prefill_args.max_seqlen_q,
        max_seqlen_k=prefill_args.max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=self.store_stream,
        attention_backend=self.config.attention_backend,
    )
    hidden_states = self.model(prefill_args.input_ids, prefill_args.positions)
    logits = self.model.compute_logits(hidden_states)
    reset_context()
    return logits
def maybe_broadcast(self, tensor: torch.Tensor):
    """
    Broadcast ``tensor`` from rank 0 when running with more than one rank.

    Returns whatever ``dist.broadcast`` returns, or ``None`` when
    ``world_size`` is 1 and nothing is sent.
    """
    if self.world_size <= 1:
        return None
    return dist.broadcast(tensor, src=0)
def maybe_release_peers(self, do_release=False):
    """
    Synchronise all ranks and optionally release the peers.

    With a single rank this does nothing. Otherwise every rank reaches a
    barrier; before it, the master clears each event in ``peer_events``
    when ``do_release`` is true.
    """
    if self.world_size <= 1:
        return
    if self.is_master and do_release:
        for event in self.peer_events:
            event.clear()
    dist.barrier()
@torch.inference_mode()
def generate(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: Optional[BatchCompressionParams] = None,
):
    """
    Generate for ``all_sequences``; master rank only.

    Sets every event in ``peer_events`` and then drives the batch loop via
    :meth:`_process_batches_master`. A default ``BatchCompressionParams()``
    is used when none is supplied.
    """
    assert self.is_master, "generate can only be called on the master process"
    for wake_peer in self.peer_events:
        wake_peer.set()
    params = (
        BatchCompressionParams()
        if batch_compression_params is None
        else batch_compression_params
    )
    self._process_batches_master(all_sequences, params)
@property
def is_master(self):
    """True on rank 0, the rank that schedules and samples."""
    master_rank = 0
    return self.rank == master_rank
@torch.inference_mode()
def _process_batches_master(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: BatchCompressionParams,
):
    """
    Master-side scheduling loop: alternate prefill batches and decode runs
    until every sequence has finished.

    Each iteration prefills the batch chosen by the scheduler, samples the
    first token, merges the new sequences into the pending decode batch and
    then either prefills again (if another batch fits) or runs the decode
    loop. With more than one rank, the run/skip decision and target
    occupancy are written to ``decode_flags`` and broadcast.
    """
    assert self.is_master
    compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
    if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
        logger.info(compression_details)
    scheduler = Scheduler(
        all_sequences=all_sequences,
        kv_manager=self.kv_manager,
        use_tqdm=self.show_progress_bar,
    )
    decode_batch = DecodeBatchArguments()
    # [0] = run decode now (0/1), [1] = desired occupancy; sent to peers.
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while not scheduler.is_finished():
        sequences = scheduler.get_prefill_batch()
        seq_ids_cpu = [seq.seq_id for seq in sequences]
        scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
        temps = torch.tensor(
            [s.sampling_params.temperature for s in sequences],
            dtype=torch.float32,
            pin_memory=True,
        ).cuda(non_blocking=True)
        prefill_arguments = self.packed_args.build_prefill_args(
            sequences, batch_compression_params=batch_compression_params
        )
        # Upper bound on each sequence's final length: prompt + max_new_tokens.
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        logits = self.run_prefill(prefill_arguments, batch_mapping)
        # Must match prefill `positions` dtype (int64). `context_lens` is int32
        # from the packed buffer; using int32 here breaks RoPE indexing
        # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        token_ids = self.sampler(logits, temps)
        # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
        # reads bh_seq_lens on the default stream and must not race.
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        # TODO: synchronize page counts across dist
        if self.world_size == 1:
            self.kv_manager.reclaim_pages(
                seq_ids_cpu, prefill_arguments.max_new_tokens
            )
            # with logging_redirect_tqdm():
            #     logger.info(
            #         f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
            #     )
        if scheduler.any_pending_sequences():
            # More work is queued: have the decode loop stop once the batch
            # shrinks to 66% of its size so another prefill can be scheduled.
            num_pending_batches = (
                0
                if decode_batch.token_ids is None
                else decode_batch.token_ids.shape[0]
            )
            occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
        else:
            # Nothing left to prefill: -1 never satisfies the decode loop's
            # occupancy test, so it runs until the batch is empty.
            occupancy = -1
        run_decode = not scheduler.can_prefill_another_batch()
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            temps,
            occupancy,
        )
        if self.world_size > 1:
            decode_flags[0] = int(run_decode)
            decode_flags[1] = occupancy
            self.maybe_broadcast(decode_flags)
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        decode_output, decode_batch = self.run_decode_loop(decode_batch)
        # Anything active that is no longer in the decode batch has finished.
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(
            finished_sequence_ids, update_status=True
        )
        self.kv_manager.free_sequences(finished_sequence_ids)
        self.maybe_release_peers(scheduler.is_finished())
        scheduler.update_sequences(
            decode_output.output_tokens.tolist(),
            decode_output.output_seq_ids.tolist(),
        )
    scheduler.close()
@torch.inference_mode()
def _process_batches_peer(self):
    """
    Peer-side counterpart of :meth:`_process_batches_master`.

    Runs while ``self.batch_ready`` is set. Each iteration performs the
    same prefill and cache allocation as the master, then receives
    ``decode_flags`` by broadcast to learn whether to enter the decode loop
    and with which occupancy. Peers never sample: their ``token_ids``
    buffer is uninitialised and is filled by the broadcast at the start of
    the decode loop.
    """
    assert not self.is_master
    # Empty scheduler: used only to track active/finished sequence ids.
    scheduler = Scheduler([], kv_manager=self.kv_manager)
    decode_batch = DecodeBatchArguments()
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while self.batch_ready.is_set():
        prefill_arguments = self.packed_args.build_prefill_args()
        B = prefill_arguments.B
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        seq_ids_cpu = prefill_arguments.seq_ids.tolist()
        scheduler.add_running_sequence_ids(seq_ids_cpu)
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        self.run_prefill(prefill_arguments, batch_mapping)
        # int64 to match the prefill positions dtype (see the master loop).
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        # Receive [run_decode, occupancy] from rank 0.
        self.maybe_broadcast(decode_flags)
        run_decode = bool(decode_flags[0].item())
        occupancy = int(decode_flags[1].item())
        token_ids = torch.empty(B, dtype=torch.int64, device="cuda")
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            None,  # temps not used in peer process
            occupancy,
        )
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        _, decode_batch = self.run_decode_loop(decode_batch)
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(finished_sequence_ids)
        self.kv_manager.free_sequences(finished_sequence_ids)
        self.maybe_release_peers()
    scheduler.close()
@torch.inference_mode()
def run_decode_loop(
    self,
    decode_batch: DecodeBatchArguments,
) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
    """
    Decode one token per step until the batch drains or shrinks enough.

    Every step broadcasts the current token ids, drops sequences that hit
    a stop token or reached ``max_ctx_lens``, and stops when no sequence
    remains or the count is at most ``desired_batch_occupancy``. Otherwise
    it runs the model (eagerly or through a captured CUDA graph) and, on
    the master, samples the next tokens.

    Returns:
        ``(output, decode_batch)``. On the master ``output`` holds the
        concatenated CPU token ids and sequence ids collected during the
        loop; on peers both fields are ``None``. ``decode_batch`` is the
        same object, filtered down to the still-running sequences.
    """
    if self.is_master:
        # Tokens before `num_stashed_batches` are skipped here; only the
        # entries after that index are queued for output.
        num_stashed_batches = decode_batch.num_stashed_batches
        tok_buffer = [
            decode_batch.token_ids[num_stashed_batches:].to(
                "cpu", non_blocking=True
            )
        ]
        seq_buffer = [
            decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
        ]
    while True:
        # Peers receive the tokens the master sampled in the previous step.
        self.maybe_broadcast(decode_batch.token_ids)
        not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
        running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
            not_stopped
        )
        # Compact every per-sequence tensor down to the running sequences.
        decode_batch.token_ids = torch.masked_select(
            decode_batch.token_ids, running_batches
        )
        decode_batch.positions = torch.masked_select(
            decode_batch.positions, running_batches
        )
        decode_batch.batch_mapping = torch.masked_select(
            decode_batch.batch_mapping, running_batches
        )
        decode_batch.max_ctx_lens = torch.masked_select(
            decode_batch.max_ctx_lens, running_batches
        )
        decode_batch.seq_ids = torch.masked_select(
            decode_batch.seq_ids, running_batches
        )
        if self.is_master:
            decode_batch.temps = torch.masked_select(
                decode_batch.temps, running_batches
            )
        num_remaining = decode_batch.token_ids.numel()
        if (
            num_remaining == 0
            or num_remaining <= decode_batch.desired_batch_occupancy
        ):
            # The survivors' latest tokens are already in the output
            # buffers; record how many to skip on the next call.
            decode_batch.num_stashed_batches = num_remaining
            break
        if self.enforce_eager:
            set_context(
                is_prefill=False,
                do_compression=False,
                batch_mapping=decode_batch.batch_mapping,
            )
            logits = self.model.compute_logits(
                self.model(decode_batch.token_ids, decode_batch.positions)
            )
        else:
            logits = self.run_graph_decode(
                decode_batch.token_ids,
                decode_batch.positions,
                decode_batch.batch_mapping,
            )
        if self.is_master:
            decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
            tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
            seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
        decode_batch.positions += 1
    if self.is_master:
        # non_blocking D2H copies must finish before cat/tolist read CPU data.
        torch.cuda.synchronize()
        output = DecodeBatchOutput(
            output_tokens=torch.cat(tok_buffer),
            output_seq_ids=torch.cat(seq_buffer),
        )
    else:
        output = DecodeBatchOutput(None, None)
    return output, decode_batch
@torch.inference_mode()
def run_graph_decode(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    batch_mapping: torch.Tensor,
):
    """
    Run one decode step by replaying a captured CUDA graph.

    The live inputs are copied into the first ``bs`` rows of the graph's
    static buffers; all ``batch_mapping`` rows are first set to
    ``RESERVED_BATCH`` so the unused rows point at the reserved slot.

    Returns:
        The first ``bs`` rows of the graph's logits buffer, or ``None`` if
        the graph recorded no logits.
    """
    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=batch_mapping,
    )
    live_rows = input_ids.shape[0]
    longest_position = int(positions.max())
    graph_dict = self.get_cuda_graph(live_rows, longest_position)

    graph_dict["input_ids"][:live_rows] = input_ids
    graph_dict["positions"][:live_rows] = positions
    graph_dict["batch_mapping"].fill_(RESERVED_BATCH)
    graph_dict["batch_mapping"][:live_rows] = batch_mapping
    graph_dict["graph"].replay()

    static_logits = graph_dict["logits"]
    if static_logits is None:
        return None
    return static_logits[:live_rows]
@torch.inference_mode()
def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
    """
    Capture a decode CUDA graph for ``batch_size`` rows.

    Placeholder sequences ``0..batch_size-1`` (256 tokens each) are
    allocated in the KV cache for the capture and freed afterwards. The
    graph and its static buffers are stored in
    ``self.captured_graphs[batch_size][max_seqlen_k]``, and
    ``self.min_captured_len[batch_size]`` tracks the smallest captured
    ``max_seqlen_k``.
    """
    dist.barrier()
    device = torch.device("cuda")
    logger.debug(
        f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
    )
    # Static buffers: replays read from and write into these tensors.
    static_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
    static_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
    static_logits = None
    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    key_split = num_splits_heuristic(
        batch_size * self.kv_manager.num_kv_heads,
        max_seq_len=max_seqlen_k,
        num_sms=sm_count,
        max_splits=12,
    )
    success, static_batch_mapping = self.kv_manager.allocate_sequences(
        list(range(batch_size)), [256] * batch_size
    )
    assert success
    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=static_batch_mapping,
        key_split=key_split,
    )
    # Eager pass before recording the graph.
    self.model.compute_logits(self.model(static_input_ids, static_positions))
    dist.barrier()
    decode_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(decode_graph):
        static_logits = self.model.compute_logits(
            self.model(static_input_ids, static_positions)
        )
    graph_vars = {
        "graph": decode_graph,
        "input_ids": static_input_ids,
        "positions": static_positions,
        "batch_mapping": static_batch_mapping,
        "logits": static_logits,
        "key_split": key_split,
    }
    per_batch = self.captured_graphs.get(batch_size)
    if per_batch is None:
        per_batch = self.captured_graphs[batch_size] = {}
        self.min_captured_len[batch_size] = float("inf")
    per_batch[max_seqlen_k] = graph_vars
    self.min_captured_len[batch_size] = min(
        max_seqlen_k, self.min_captured_len[batch_size]
    )
    self.kv_manager.free_sequences(list(range(batch_size)))
def get_cuda_graph(self, batch_size: int, max_seqlen_k: int):
    """
    Pick the captured CUDA graph to replay for a decode step.

    The graph set with the smallest captured batch size that is at least
    ``batch_size`` is chosen, independent of capture order. Within that
    set, the graph with the largest captured sequence length not exceeding
    ``max_seqlen_k`` is returned; if every captured length is larger, the
    smallest captured length for that batch size is used.

    Raises:
        RuntimeError: if no captured batch size is >= ``batch_size``.
    """
    adequate_sizes = [bs for bs in self.captured_graphs if bs >= batch_size]
    if not adequate_sizes:
        raise RuntimeError(
            f"No CUDA graph captured for batch size >= {batch_size} "
            f"(captured batch sizes: {sorted(self.captured_graphs)})"
        )
    padded_batch_size = min(adequate_sizes)
    graphs_by_len = self.captured_graphs[padded_batch_size]
    # Largest captured seq_len <= max_seqlen_k, else the smallest captured one.
    best = max(
        (seq_len for seq_len in graphs_by_len if seq_len <= max_seqlen_k),
        default=self.min_captured_len[padded_batch_size],
    )
    return graphs_by_len[best]
import time
from typing import Iterable, List
from compactor_vllm.core.memory_manager import KVCacheManager
from compactor_vllm.utils.sequence import Sequence, SequenceStatus
from tqdm import tqdm
def cdiv(a, b):
    """Divide ``a`` by ``b`` rounding up (for positive ``b``)."""
    padded_numerator = a + b - 1
    return padded_numerator // b
class Scheduler:
"""
Simple sequence scheduler for prefill + decode with a paged KV cache.
The scheduler tracks three disjoint sets of sequence IDs:
* ``pending_sequence_ids`` – sequences that have not yet been started.
* ``active_sequence_ids`` – sequences currently running.
* ``finished_sequence_ids`` – sequences that have generated all tokens.
At prefill time, :meth:`get_prefill_batch` selects a subset of pending
sequences that can fit into the available KV cache and per-step token
budget, given the constraints from the associated :class:`KVCacheManager`.
The class also handles basic bookkeeping of sequence statuses.
Args:
:param all_sequences:
Iterable of :class:`Sequence` objects to be scheduled. Each
sequence must have a unique ``seq_id``.
:param kv_manager:
A :class:`KVCacheManager` instance that this scheduler will use
to determine whether additional batches can be scheduled.
:param use_tqdm:
If True, a single progress bar ("Completed Batches") is created with
one tick per scheduled sequence. It advances when a sequence is
recorded as finished with ``update_status=True``, and its description
is replaced by a throughput readout as tokens are generated.
"""
def __init__(
    self,
    all_sequences: Iterable[Sequence],
    kv_manager: KVCacheManager,
    *,
    use_tqdm=False,
):
    """
    Register every sequence as pending and start the throughput clock.

    Args:
        :param all_sequences:
            Sequences to schedule, keyed internally by ``seq_id``. Read in
            a single pass, so a one-shot iterator is accepted.
        :param kv_manager:
            Cache manager queried for free batch slots, free pages and the
            batched-token limit.
        :param use_tqdm:
            If True, create a "Completed Batches" progress bar with one
            tick per sequence.
    """
    self.allseq_mapping: dict[int, Sequence] = {s.seq_id: s for s in all_sequences}
    # Derived from the mapping rather than from a second pass over
    # `all_sequences`, which would be empty for an exhausted iterator.
    self.pending_sequence_ids: set[int] = set(self.allseq_mapping)
    self.active_sequence_ids: set[int] = set()
    self.finished_sequence_ids: set[int] = set()
    self.manager = kv_manager
    self.use_tqdm = use_tqdm
    self.start_time = time.perf_counter()
    self.total_tokens_generated = 0
    self.total_tokens_input = 0
    self.pbar = None
    if use_tqdm:
        self.pbar = tqdm(
            total=len(self.pending_sequence_ids),
            desc="Completed Batches",
        )
def get_prefill_batch(self) -> List[Sequence]:
    """
    Greedily choose pending sequences that fit into the next prefill.

    Pending ids are visited in set iteration order. A sequence is taken
    when all of the following hold, given what was already taken in this
    call:

    * the running prompt-token total plus its prompt length stays within
      ``manager.max_batched_tokens``;
    * a free batch slot remains;
    * the pages it needs for prompt + ``max_new_tokens`` across all KV
      heads are strictly fewer than the remaining free pages.

    Scheduler state is not modified; the caller marks the returned
    sequences as running.

    Returns:
        :return List[Sequence]:
            The selected sequences, possibly empty.
    """
    chosen = []
    tokens_taken = 0
    slots_left = self.manager.num_free_batches
    pages_left = self.manager.num_free_pages
    for seq_id in self.pending_sequence_ids:
        candidate = self.allseq_mapping[seq_id]
        prompt_length = candidate.prompt_len
        pages_per_head = cdiv(
            prompt_length + candidate.sampling_params.max_new_tokens,
            self.manager.page_size,
        )
        pages_needed = pages_per_head * self.manager.num_kv_heads
        if prompt_length + tokens_taken > self.manager.max_batched_tokens:
            continue
        if slots_left <= 0:
            continue
        if pages_needed >= pages_left:
            continue
        chosen.append(candidate)
        tokens_taken += prompt_length
        pages_left -= pages_needed
        slots_left -= 1
    return chosen
def is_finished(self) -> bool:
    """
    Report whether nothing is pending and nothing is running.
    """
    return not (self.pending_sequence_ids or self.active_sequence_ids)
def any_pending_sequences(self) -> bool:
    """
    Report whether at least one sequence has not been started yet.
    """
    return bool(self.pending_sequence_ids)
def add_running_sequence_ids(
    self, active_sequence_ids: Iterable[int], *, update_status: bool = False
):
    """
    Move sequence ids from pending to active.

    Args:
        :param active_sequence_ids:
            Ids that have been scheduled and should now count as running.
            Any iterable is accepted, including a one-shot iterator.
        :param update_status:
            If True, also set each :class:`Sequence`'s
            ``status = SequenceStatus.RUNNING`` and add its ``prompt_len``
            to ``total_tokens_input``.
    """
    # Materialise once: the ids are needed both for the set update and for
    # the optional status pass below.
    newly_active = list(active_sequence_ids)
    self.active_sequence_ids.update(newly_active)
    self.pending_sequence_ids.difference_update(self.active_sequence_ids)
    if update_status:
        for seq_id in newly_active:
            self.allseq_mapping[seq_id].status = SequenceStatus.RUNNING
            self.total_tokens_input += self.allseq_mapping[seq_id].prompt_len
def get_finished_sequence_ids_from_unfinished(
    self, unfinished_sequence_ids: Iterable[int]
) -> set[int]:
    """
    Work out which active sequences are no longer running.

    Args:
        :param unfinished_sequence_ids:
            Ids the caller knows are still generating.
    Returns:
        :return set[int]:
            Active ids that are absent from ``unfinished_sequence_ids``.
            The scheduler's own sets are left untouched.
    """
    still_running = set(unfinished_sequence_ids)
    return {
        seq_id for seq_id in self.active_sequence_ids if seq_id not in still_running
    }
def record_finished_sequence_ids(
    self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
):
    """
    Move sequence ids from active to finished.

    Args:
        :param finished_sequence_ids:
            Ids that have completed generation. Any iterable is accepted,
            including a one-shot iterator.
        :param update_status:
            If True, set each :class:`Sequence`'s
            ``status = SequenceStatus.FINISHED`` and advance the progress
            bar (when present) by one per sequence.
    """
    # Materialise once: the ids are used up to three times below, and an
    # iterator would be exhausted after the first use.
    finished = list(finished_sequence_ids)
    self.active_sequence_ids.difference_update(finished)
    self.finished_sequence_ids.update(finished)
    if update_status:
        for seq_id in finished:
            self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
            if self.pbar is not None:
                self.pbar.update(1)
def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
    """
    Deliver generated tokens to their sequences and refresh the readout.

    Args:
        :param tokens:
            Generated token ids.
        :param seq_ids:
            Sequence id for each entry of ``tokens``, in the same order.

    Each pair increments ``total_tokens_generated``. If a progress bar
    exists, its description becomes the overall throughput (prompt plus
    generated tokens per second since construction).
    """
    now = time.perf_counter()
    for token, owner_id in zip(tokens, seq_ids):
        self.allseq_mapping[owner_id].add_new_token(token)
        self.total_tokens_generated += 1
    if self.pbar is None:
        return
    processed = self.total_tokens_generated + self.total_tokens_input
    elapsed = now - self.start_time
    self.pbar.set_description(f"Throughput: {processed / elapsed:.2f} tok/s")
def close(self):
    """Close the progress bar, if one was created."""
    if self.pbar is None:
        return
    self.pbar.close()
def can_prefill_another_batch(self) -> bool:
    """Report whether :meth:`get_prefill_batch` would select any sequence now."""
    next_batch = self.get_prefill_batch()
    return bool(next_batch)
import heapq
import logging
from enum import Enum, auto
from typing import List, Optional, Union
import torch
from compactor_vllm.config.constants import RESERVED_BATCH
from compactor_vllm.kv_cache.write_page_table import scatter_to_page_table
logger = logging.getLogger(__name__)
def cdiv(a, b):
    """Divide ``a`` by ``b`` rounding up (for positive ``b``)."""
    padded_numerator = a + b - 1
    return padded_numerator // b
def next_multiple(a, b):
    """Round ``a`` up to a multiple of ``b`` (for positive ``b``)."""
    whole_units = (a + b - 1) // b
    return whole_units * b
class KVAllocationStatus(Enum):
    """Outcome of a KV-cache allocation request."""

    # A (layer, head) would need more pages than max_pages_per_head.
    EXCEEDS_MAX_SEQUENCE_LENGTH = auto()
    # Some layer has fewer free physical pages than the request needs.
    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = auto()
    # NOTE(review): presumably returned when no batch slot is free; the
    # code that produces it is not in this part of the file.
    EXCEEDS_MAX_NUM_BATCHES = auto()
    # The requested capacity is available.
    SUCCESS = auto()
class PagedKVCache(torch.nn.Module):
"""
Global paged KV cache.
This module manages:
* A global K/V backing buffer for all layers:
``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
where the first dimension indexes K vs V.
* A per-layer page table:
``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
mapping logical (batch, kv-head, logical_page) to a physical page ID
in the global K/V buffer.
* Per-layer, per-(batch, kv-head) logical sequence lengths
``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
the number of allocated pages ``bh_num_pages`` for each (layer, batch,
head).
* A page allocator implemented as a min-heap of free physical pages
per layer, plus free batch indices.
Pages are of fixed size ``page_size`` tokens.
Args:
:param num_layers:
Number of transformer layers that will use this cache.
:param max_logical_pages_per_head:
Maximum number of logical pages that can be assigned to a single
(batch, kv-head) pair.
:param num_pages:
Total number of physical pages available in the global cache per
layer. The global K/V buffers are of length
``num_pages * page_size`` along the token dimension.
:param page_size:
Number of tokens stored per page.
:param H_kv:
Number of KV heads per layer.
:param head_dim:
Head dimension for K/V.
:param max_num_batches:
Maximum number of concurrent batches / sequences supported. One
batch index is reserved for internal use (``RESERVED_BATCH``).
:param dtype:
Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
:param device:
Device on which to allocate the cache (string, torch.device, or
int; defaults to ``"cuda"``).
"""
def __init__(
    self,
    num_layers: int,
    max_logical_pages_per_head: int,
    num_pages: int,
    page_size: int,  # tokens per page
    H_kv: int,
    head_dim: int,
    max_num_batches: int,
    dtype: torch.dtype,
    device: Union[str, torch.device, int] = "cuda",
):
    """
    Allocate the K/V buffer, page table and allocator bookkeeping.

    One batch slot more than ``max_num_batches`` is created, and the slot
    equal to ``RESERVED_BATCH`` is withheld from the free list.
    """
    super().__init__()
    self.n_pages = num_pages
    self.num_layers = num_layers
    self.page_size: int = int(page_size)
    self.H_kv = int(H_kv)
    self.max_pages_per_head = max_logical_pages_per_head
    # One extra slot on top of what the caller asked for.
    total_slots = max_num_batches + 1
    self.max_num_batches = total_slots
    self.head_dim = head_dim

    # Leading dimension of size 2 selects K vs V.
    self.kv_cache = torch.empty(
        (2, num_layers, num_pages * page_size, head_dim),
        dtype=dtype,
        device=device,
    )
    self.page_table = torch.empty(
        (num_layers, total_slots, H_kv, self.max_pages_per_head),
        device=device,
        dtype=torch.int32,
    )
    per_head_shape = (num_layers, total_slots, H_kv)
    # Tokens stored per (layer, batch, head).
    self.bh_seq_lens = torch.zeros(per_head_shape, device=device, dtype=torch.int32)
    # Pages allocated per (layer, batch, head).
    self.bh_num_pages = torch.zeros(per_head_shape, device=device, dtype=torch.int32)

    # Free physical pages: one min-heap per layer.
    self.free_pages: List[List[int]] = []
    for _ in range(num_layers):
        layer_heap = list(range(num_pages))
        heapq.heapify(layer_heap)
        self.free_pages.append(layer_heap)

    # Descending order so that pop() hands out the lowest free index first.
    self.free_batches: List[int] = list(range(total_slots - 1, -1, -1))
    self.free_batches.remove(RESERVED_BATCH)

    # Physical page ids owned by each batch slot, per layer (for freeing).
    self.pages_indices_per_batch: List[List[set[int]]] = [
        [set() for _ in range(num_layers)] for _ in range(total_slots)
    ]
def new_batch(self) -> Optional[int]:
    """
    Reserve a batch slot if capacity allows.

    A slot is handed out only when the free list is non-empty and every
    layer still has at least ``H_kv`` free pages.

    Returns:
        :return Optional[int]:
            The slot index, popped from the end of ``free_batches``, or
            ``None`` when either condition fails.
    """
    if not self.free_batches:
        return None
    if any(len(layer_free) < self.H_kv for layer_free in self.free_pages):
        return None
    return self.free_batches.pop()
def reserve_tokens(self, batch_index: int, add_tokens: int) -> KVAllocationStatus:
    """
    Ensure enough pages are allocated to hold ``add_tokens`` more tokens.

    Args:
        :param batch_index:
            Batch slot to reserve space for.
        :param add_tokens:
            Extra tokens to make room for; the same amount is added for
            every layer and head of the slot.
    Returns:
        :return KVAllocationStatus:
            ``SUCCESS`` if the capacity exists or was allocated;
            ``EXCEEDS_MAX_SEQUENCE_LENGTH`` if some head would need more
            than ``max_pages_per_head`` pages;
            ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` if some layer lacks free
            pages. No state is changed on failure.
    """
    cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
    curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
    curr_cap_tokens = curr_pages * self.page_size  # [L, H]
    need_tokens = cur_bh_lens + add_tokens  # [L, H]
    if (need_tokens <= curr_cap_tokens).all():
        return KVAllocationStatus.SUCCESS
    missing_tokens = need_tokens - curr_cap_tokens
    # A head with a whole spare page has missing_tokens <= -page_size, for
    # which the ceiling division is negative. Such heads need no new pages
    # and must not lower the totals computed below.
    add_pages = torch.clamp(cdiv(missing_tokens, self.page_size), min=0)
    new_total_pages = curr_pages + add_pages
    if (new_total_pages > self.max_pages_per_head).any():
        return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH
    # CPU work
    pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
    new_phys_pages = []
    # Check every layer before popping anything so a failure leaves the
    # free lists untouched.
    for layer_index in range(self.num_layers):
        if pages_per_layer_cpu[layer_index] > len(self.free_pages[layer_index]):
            return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES
    for layer_index in range(self.num_layers):
        this_layer_pages = [
            heapq.heappop(self.free_pages[layer_index])
            for _ in range(pages_per_layer_cpu[layer_index])
        ]
        self.pages_indices_per_batch[batch_index][layer_index] |= set(
            this_layer_pages
        )
        new_phys_pages.extend(this_layer_pages)
    # Create the ids on the cache's own device rather than whichever CUDA
    # device happens to be current.
    new_phys_pages = torch.tensor(
        new_phys_pages, dtype=torch.int32, device=self.page_table.device
    )
    scatter_to_page_table(
        add_pages=add_pages,
        new_phys_pages=new_phys_pages,
        curr_pages=curr_pages,
        page_table=self.page_table[:, batch_index],
        max_pages_per_head=self.max_pages_per_head,
    )
    self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
        self.bh_num_pages.dtype
    )
    return KVAllocationStatus.SUCCESS
def reclaim_pages(
    self,
    batch_index: int,
    future_reserve_tokens: int = 0,
):
    """
    Return the trailing, unused pages of one batch index to the free lists.

    Every (layer, head) of the batch is shrunk to the number of pages needed
    for its current sequence length plus ``future_reserve_tokens``.

    Args:
        :param batch_index:
            Batch index whose pages should be compacted.
        :param future_reserve_tokens:
            Extra tokens of capacity to keep beyond the current sequence
            length, to reduce churn for sequences expected to grow soon.
    Returns:
        :return int:
            Approximate number of bytes freed across both K and V
            (``0`` when nothing could be reclaimed).
    """
    _, num_batches, num_heads = self.bh_seq_lens.shape
    assert 0 <= batch_index < num_batches
    pages_per_head = self.max_pages_per_head

    wanted_tokens = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
    allocated = self.bh_num_pages[:, batch_index, :]  # [L, H]
    # Flattened [L * H * P] view of this batch's logical -> physical mapping.
    flat_table = self.page_table[:, batch_index, :, :].reshape(-1)

    # Pages that must stay: ceil(tokens / page_size), never more than allocated.
    keep_pages = torch.minimum(cdiv(wanted_tokens, self.page_size), allocated)

    # Logical page slots 0..P-1, broadcast against [L, H, 1].
    slot = torch.arange(
        pages_per_head, device=self.bh_seq_lens.device, dtype=torch.int32
    ).view(1, 1, pages_per_head)
    is_allocated = slot < allocated.unsqueeze(-1)
    is_surplus = slot >= keep_pages.unsqueeze(-1)
    release_flat = (is_allocated & is_surplus).view(-1)  # [L * H * P]
    if not release_flat.any():
        return 0

    release_idx = torch.nonzero(release_flat, as_tuple=True)[0]
    # flat index = ((l * H) + h) * P + p, so the layer is idx // (H * P)
    layer_of_slot = (release_idx // (num_heads * pages_per_head)).to(torch.int32)
    released_pages = flat_table[release_idx].tolist()
    released_layers = layer_of_slot.tolist()

    self.bh_num_pages[:, batch_index, :] = keep_pages
    batch_pages = self.pages_indices_per_batch[batch_index]
    for page_id, layer_id in zip(released_pages, released_layers):
        batch_pages[layer_id].remove(page_id)
        heapq.heappush(self.free_pages[layer_id], page_id)

    bytes_per_page = self.page_size * self.head_dim * self.kv_cache.element_size()
    # Factor of two: each page exists once for K and once for V.
    return len(released_pages) * bytes_per_page * 2
def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
    """
    Hand every page that ``batch_index`` owns in ``layer_index`` back to that
    layer's free-page heap and leave the batch with an empty page set there.
    """
    layer_heap = self.free_pages[layer_index]
    owned_by_layer = self.pages_indices_per_batch[batch_index]
    for page_id in owned_by_layer[layer_index]:
        heapq.heappush(layer_heap, int(page_id))
    # Rebind to a fresh set instead of clearing the old one in place.
    owned_by_layer[layer_index] = set()
def free_batch(self, batch_index: int) -> None:
    """
    Release everything held by a batch index.

    Pages of every layer go back to the free heaps, the per-(layer, head)
    sequence lengths and page counts of the batch are zeroed, and the index
    is appended to ``free_batches``.

    Args:
        :param batch_index:
            Batch index to release. Must have been previously allocated
            via :meth:`new_batch`.
    """
    for layer_index in range(self.num_layers):
        self._free_batch_layer(layer_index, batch_index)
    for per_head_state in (self.bh_seq_lens, self.bh_num_pages):
        per_head_state[:, batch_index].zero_()
    self.free_batches.append(batch_index)
def layer_slices(self, layer: int):
    """
    Return the layer-local views needed by the attention module.

    Args:
        :param layer:
            Layer index ``l`` in ``[0, num_layers)``.
    Returns:
        :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(k, v, pt, bh)``: ``kv_cache[0, layer]``, ``kv_cache[1, layer]``,
            ``page_table[layer]`` and ``bh_seq_lens[layer]``.
    """
    assert 0 <= layer < self.num_layers
    key_cache, value_cache = (self.kv_cache[kv_slot, layer] for kv_slot in (0, 1))
    return key_cache, value_cache, self.page_table[layer], self.bh_seq_lens[layer]
import torch
import triton
import triton.language as tl
from compactor_vllm.config.constants import (
TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
)
# Scatter the top-ranked (token, head) KV vectors of a prefill into the paged
# cache. Grid: (local batch, tile of K_TILE selections). Each selected entry
# atomically claims the next slot of its (batch, head) row in ``bh_lens``, so
# the order in which tokens land inside a head's cache is not deterministic.
@triton.jit
def _prefill_store_topk_kv_kernel(
    key,
    value,  # [N_total, H, D] (D stride assumed 1)
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices_topk,  # [B, MAX_SEL] int32 (across all heads)
    # Lengths & page table:
    bh_lens,  # [B, H] int32 (contiguous), incremented atomically
    page_table,  # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    sk_n,
    sk_h,  # strides for key,value. D stride assumed 1
    sv_n,
    sv_h,
    # Runtime ints
    MAX_SEL,  # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # how many selected tokens each program processes
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    b_local = tl.program_id(0)
    tile_id = tl.program_id(1)
    offs = tl.arange(0, D)
    # how many tokens we actually keep for this batch
    k_total = tl.load(num_tokens_to_retain + b_local)
    if k_total == 0:
        return
    # map to true batch row in the page table
    b_true = tl.load(batch_mapping + b_local)
    if b_true == TRITON_RESERVED_BATCH:
        return
    base = tile_id * K_TILE
    # process up to K_TILE tokens
    for j in tl.range(0, K_TILE):
        sel_idx = base + j
        if sel_idx < k_total and sel_idx < MAX_SEL:
            # flattened selection: sel = token * H + head
            sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
            tok = sel // HKV
            head = sel - (tok * HKV)
            # atomically reserve one position in (b_local, head)
            # i.e. the KV cache is scrambled when storing
            len_ptr = bh_lens + b_local * HKV + head
            pos = tl.atomic_add(len_ptr, 1)  # old length (int32)
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            # translate logical page to physical page
            pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            # destination row and element offset (int64 via phys)
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row * D + offs
            # load one vector from [N_total, H, D]; widen the token index so
            # tok * stride cannot wrap in int32 for long inputs / wide strides
            tok64 = tok.to(tl.int64)
            k_src = key + tok64 * sk_n + head * sk_h + offs
            v_src = value + tok64 * sv_n + head * sv_h + offs
            tl.store(
                k_cache + dst_off,
                tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
def prefill_store_topk_kv(
    *,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    indices_topk: torch.Tensor,  # [B, MAX_SEL] int32 (global flattened token*H + head)
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    cu_seqlens_k: torch.Tensor | None = None,
    K_TILE: int = 16,
    TRITON_RESERVED_BATCH: int | None = None,
):
    """
    Store the selected (token, head) KV pairs of a prefill into the paged cache.

    Launches ``_prefill_store_topk_kv_kernel`` over a ``(B, tiles)`` grid to
    write the first ``num_tokens_to_retain[b]`` entries of ``indices_topk[b]``.
    When ``PAD_TO_PAGE_SIZE`` is set it then launches
    ``_prefill_store_topk_pad_kernel`` over ``(B, H)`` to top up each head's
    last partially filled page with the next-ranked candidates of that head.

    ``bh_lens`` is updated in place with the new per-(batch, head) lengths.

    Args:
        new_keys / new_vals: source K/V, last dim contiguous, ``D`` a power of 2.
        indices_topk: ranked flattened selections per batch row.
        num_tokens_to_retain: how many ranked entries to keep per batch row.
        page_table: contiguous logical -> physical page mapping.
        batch_mapping: local batch row -> row of ``page_table``; rows equal to
            ``TRITON_RESERVED_BATCH`` are skipped by the kernels.
        bh_lens: per-(batch, head) cache lengths, updated by the kernels.
        k_cache / v_cache: contiguous destination caches.
        PAGE_SIZE: tokens per page.
        PAD_TO_PAGE_SIZE: run the padding pass (requires ``cu_seqlens_k``).
        cu_seqlens_k: cumulative key lengths, needed only for padding.
        K_TILE: selections handled by one program of the first kernel.
        TRITON_RESERVED_BATCH: sentinel batch id; defaults to the package constant.
    """
    assert new_keys.shape == new_vals.shape
    N_total, H, D = new_keys.shape
    B = indices_topk.shape[0]
    assert page_table.shape[1] == H
    assert bh_lens.shape == (B, H)
    assert new_keys.device == k_cache.device == v_cache.device
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
        "new_keys/new_vals last dim must be contiguous."
    )
    assert (D & (D - 1)) == 0, "D must be a power of 2"

    page_table = page_table.to(torch.int32)
    # ``.to`` returns a *copy* when the dtype differs; remember the caller's
    # tensor so the kernel's length updates can be written back afterwards.
    caller_bh_lens = bh_lens
    bh_lens = bh_lens.to(torch.int32)
    batch_mapping = batch_mapping.to(torch.int32)
    # Both kernels index these with plain row-major arithmetic, so they must
    # be contiguous before the first launch, not only before the padding pass.
    indices_topk = indices_topk.to(torch.int32).contiguous()
    num_tokens_to_retain = num_tokens_to_retain.to(torch.int32).contiguous()

    # strides (elements) for [N_total, H, D]
    sk_n, sk_h, _ = new_keys.stride()
    sv_n, sv_h, _ = new_vals.stride()

    # tile second grid dim
    MAX_SEL = indices_topk.shape[-1]
    N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
    grid = (B, max(1, N_TILES))
    if TRITON_RESERVED_BATCH is None:
        TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH

    _prefill_store_topk_kv_kernel[grid](
        key=new_keys,
        value=new_vals,
        batch_mapping=batch_mapping,
        num_tokens_to_retain=num_tokens_to_retain,
        indices_topk=indices_topk,
        bh_lens=bh_lens,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_n=sk_n,
        sk_h=sk_h,
        sv_n=sv_n,
        sv_h=sv_h,
        MAX_SEL=int(MAX_SEL),
        HKV=H,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=D,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
        TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
    )
    if PAD_TO_PAGE_SIZE:
        assert cu_seqlens_k is not None
        _prefill_store_topk_pad_kernel[(B, H)](
            key=new_keys,
            value=new_vals,
            batch_mapping=batch_mapping,
            num_tokens_to_retain=num_tokens_to_retain,
            indices=indices_topk,
            local_lens=bh_lens,
            page_table_flat=page_table,
            k_cache=k_cache,
            v_cache=v_cache,
            cu_seqlens_k=cu_seqlens_k,
            sk_n=sk_n,
            sk_h=sk_h,
            sv_n=sv_n,
            sv_h=sv_h,
            MAX_SEL=int(MAX_SEL),
            H=H,  # type: ignore
            N_LOGICAL_PAGES_MAX=page_table.shape[2],  # type: ignore
            D=D,  # type: ignore
            PAGE_SIZE=PAGE_SIZE,  # type: ignore
            TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
        )
    if bh_lens is not caller_bh_lens:
        # A dtype conversion produced a private copy above; propagate the
        # updated lengths so the documented in-place update actually happens.
        caller_bh_lens.copy_(bh_lens)
# Pad one (batch, head) row of the paged cache up to the next page boundary.
# Grid: (local batch, head). Starting right after the retained entries of
# ``indices``, the ranked list is scanned for further candidates belonging to
# this head; each one found is appended until the page is full, the list is
# exhausted, or the head would exceed the batch's context length. The new
# length is written back to ``local_lens``.
@triton.jit
def _prefill_store_topk_pad_kernel(
    key,  # [N_total, H, D]
    value,  # [N_total, H, D]
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices,  # [B, MAX_SEL] int32 (across all heads)
    local_lens,  # [B, H] int32 (contiguous)
    page_table_flat,  # [B_total*H*N_LOGICAL_PAGES_MAX] int32
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    cu_seqlens_k,
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    MAX_SEL,
    # Constexprs
    H: tl.constexpr,  # number of KV heads
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    b_local = tl.program_id(0)
    h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    L = tl.load(local_lens + b_local * H + h)
    modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
    if modulo_page_size == 0:
        # already page aligned, nothing to pad
        return
    need = PAGE_SIZE - modulo_page_size
    b_true = tl.load(batch_mapping + b_local)
    if b_true == TRITON_RESERVED_BATCH:
        return
    pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
    written_tokens = 0
    # candidates start right after the entries the first pass already stored
    idx = tl.load(num_tokens_to_retain + b_local)
    this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
        cu_seqlens_k + b_local
    )
    # never hold more tokens for a head than the batch's context length
    max_additional = this_batch_ctx_len - L
    while (written_tokens < need and idx < MAX_SEL) and (
        written_tokens < max_additional
    ):
        # candidate head
        cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
        cand_h = cand_idx % H
        if cand_h == h:
            tok = cand_idx // H
            pos = L + written_tokens
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            # widen before multiplying, matching the other store kernels, so
            # row and source offsets are computed in int64
            phys = tl.load(page_table_flat + pt_base + lp).to(tl.int64)
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row * D + offs_d
            tok64 = tok.to(tl.int64)
            k_src = key + tok64 * sk_n + h * sk_h + offs_d
            v_src = value + tok64 * sv_n + h * sv_h + offs_d
            tl.store(
                k_cache + dst_off,
                tl.load(k_src),
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src),
            )
            written_tokens += 1
        idx += 1
    tl.store(local_lens + b_local * H + h, L + written_tokens)
# Copy every (token, head) KV vector of a prefill into the paged cache,
# appending after the entries already counted in ``bh_lens``.
# Grid: (local batch, tile of K_TILE (token, head) pairs).
# This kernel only READS ``bh_lens``; it does not store the new lengths.
@triton.jit
def _prefill_store_all_kv_kernel(
    key,
    value,  # [N, H, D] (D contiguous)
    cu_seqlens_k,  # [B + 1] int32
    batch_mapping,  # [B] int32 (local b -> true batch index)
    bh_lens,  # [B * HKV] int32 (read here as the starting length per head)
    pt_flat,  # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    # source strides (elements)
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    # constexpr
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # number of (token, head) pairs processed per program
):
    pid_b = tl.program_id(0)
    pid_blk = tl.program_id(1)
    # token range [start, end) of this batch row in the packed inputs
    start = tl.load(cu_seqlens_k + pid_b)
    end = tl.load(cu_seqlens_k + pid_b + 1)
    num_toks_this_batch = end - start
    if num_toks_this_batch <= 0:
        return
    total_elems = num_toks_this_batch * HKV
    # base linear index in (token, head) grid for this program
    base = pid_blk * K_TILE
    offs_d = tl.arange(0, D)
    # Iterate K_TILE elements in this tile
    for i in tl.range(0, K_TILE):
        idx = base + i
        # tiles past the end of a short sequence do nothing
        if idx < total_elems:
            # map linear idx -> (t, h)
            t = idx // HKV
            h = idx - t * HKV
            len_idx = pid_b * HKV + h
            # existing length of this head; token t lands at L0 + t
            L0 = tl.load(bh_lens + len_idx)
            token_idx_in_cache = L0 + t
            lp = token_idx_in_cache // PAGE_SIZE  # logical page
            off_in_pg = token_idx_in_cache - lp * PAGE_SIZE  # pos in page
            # physical page
            b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
            pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
            phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
            # int64 row / element offsets into the flat cache
            row = phys * PAGE_SIZE + off_in_pg
            dst_off = row * D + offs_d
            n_global = (start + t).to(tl.int64)
            # Use strides for non-contiguous [N, H, D] (D stride == 1)
            k_src = key + n_global * sk_n + h * sk_h + offs_d
            v_src = value + n_global * sv_n + h * sv_h + offs_d
            tl.store(k_cache + dst_off, tl.load(k_src))
            tl.store(v_cache + dst_off, tl.load(v_src))
def prefill_store_all_kv(
    *,
    new_keys: torch.Tensor,
    new_values: torch.Tensor,  # [N, H_kv, D]
    cu_seqlens_k: torch.Tensor,  # [B + 1] int32
    max_seqlen_k: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,  # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
    bh_lens: torch.Tensor,  # [B, H_kv] int32 (UPDATED)
    batch_mapping: torch.Tensor,  # [B] int32 (local->true)
    PAGE_SIZE: int,
    K_TILE: int = 32,  # how many (token, head) pairs per program
):
    """
    Append all prefill keys/values to the paged KV cache (no compression).

    Launches ``_prefill_store_all_kv_kernel`` on a ``(B, tiles)`` grid sized
    for the longest sequence, then advances ``bh_lens`` in place by each
    sequence's length (``cu_seqlens_k.diff()``), broadcast over heads.
    """
    assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
        "last dim must be contiguous"
    )
    assert page_table.is_contiguous(), "page table must be contiguous"
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
    assert k_cache.is_contiguous() and v_cache.is_contiguous()

    _, num_kv_heads, head_dim = new_keys.shape
    num_seqs = batch_mapping.shape[0]
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_n, key_stride_h, _ = new_keys.stride()
    val_stride_n, val_stride_h, _ = new_values.stride()

    # Ceil-divide the largest possible (token, head) count into tiles.
    pairs_upper_bound = max_seqlen_k * num_kv_heads
    tiles_per_seq = -(-pairs_upper_bound // K_TILE)
    launch_grid = (num_seqs, tiles_per_seq)

    _prefill_store_all_kv_kernel[launch_grid](
        key=new_keys,
        value=new_values,
        cu_seqlens_k=cu_seqlens_k,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        pt_flat=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_n=key_stride_n,
        sk_h=key_stride_h,
        sv_n=val_stride_n,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[-1],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
    )
    # The kernel only reads the lengths; bump them here, once per head.
    bh_lens.add_(torch.diff(cu_seqlens_k).unsqueeze(1))
# Append one new K/V vector per (batch, head) to the paged cache during decode
# and increment that head's length. Grid: (local batch, head).
@triton.jit
def _decode_store_kv_kernel(
    key,
    value,
    batch_mapping,  # [B] int32
    bh_lens,  # [B*HKV] int32 (read, then written back as length + 1)
    page_table,  # [B_total*HKV*N_LOGICAL_PAGES_MAX]
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    sk_b,
    sk_h,
    sv_b,
    sv_h,
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    pid_b = tl.program_id(0)
    h = tl.program_id(1)
    mapped_b = tl.load(batch_mapping + pid_b)
    # rows mapped to the reserved sentinel are skipped entirely
    if mapped_b == TRITON_RESERVED_BATCH:
        return
    offs_d = tl.arange(0, D)
    # current length == index of the slot the new token goes into
    length = tl.load(bh_lens + pid_b * HKV + h)
    logical_page = length // PAGE_SIZE
    internal_offset = length - logical_page * PAGE_SIZE
    # page table is indexed by the mapped (true) batch row
    pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
    physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
    dst_row = physical_page * PAGE_SIZE + internal_offset
    # Source addressing using strides (D stride == 1)
    k_src = key + pid_b * sk_b + h * sk_h + offs_d
    v_src = value + pid_b * sv_b + h * sv_h + offs_d
    dst_off = dst_row * D + offs_d
    tl.store(k_cache + dst_off, tl.load(k_src))
    tl.store(v_cache + dst_off, tl.load(v_src))
    # lengths are indexed by the LOCAL batch row, unlike the page table
    tl.store(bh_lens + pid_b * HKV + h, length + 1)
def decode_store_kv(
    *,
    key: torch.Tensor,  # [B, HKV, D]
    value: torch.Tensor,  # [B, HKV, D]
    batch_mapping: torch.Tensor,  # [B] int32
    bh_lens: torch.Tensor,  # [B, HKV] or flattened [B*HKV] int32
    page_table: torch.Tensor,  # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,  # [N_PAGES*PAGE_SIZE, D]
    PAGE_SIZE: int,
    TRITON_RESERVED_BATCH: int = None,
):
    """
    Write the single new key/value of every (batch, head) into the paged cache.

    Validates the layouts, then launches ``_decode_store_kv_kernel`` with one
    program per (entry of ``batch_mapping``, KV head). The kernel increments
    ``bh_lens`` in place. When ``TRITON_RESERVED_BATCH`` is ``None`` the
    package-level sentinel is used.
    """
    assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
    _, num_kv_heads, head_dim = key.shape
    assert key.stride(-1) == 1 and value.stride(-1) == 1, (
        "key/value last dim must be contiguous."
    )
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_b, key_stride_h, _ = key.stride()
    val_stride_b, val_stride_h, _ = value.stride()

    if TRITON_RESERVED_BATCH is not None:
        reserved_batch = TRITON_RESERVED_BATCH
    else:
        reserved_batch = _TRITON_RESERVED_BATCH

    launch_grid = (int(batch_mapping.shape[0]), num_kv_heads)
    _decode_store_kv_kernel[launch_grid](
        key=key,
        value=value,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_b=key_stride_b,
        sk_h=key_stride_h,
        sv_b=val_stride_b,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        TRITON_RESERVED_BATCH=reserved_batch,
    )
import torch
import triton
import triton.language as tl
def scatter_to_page_table(
    add_pages: torch.Tensor,  # [L, H] int32
    new_phys_pages: torch.Tensor,  # [N]
    curr_pages: torch.Tensor,  # [L, H] int32
    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
    max_pages_per_head: int,
):
    """
    Append newly allocated physical pages into a layered page table via Triton.

    For each (layer ``l``, head ``h``) the next ``add_pages[l, h]`` entries of
    ``new_phys_pages`` are written to logical pages
    ``curr_pages[l, h] .. curr_pages[l, h] + add_pages[l, h] - 1`` of
    ``page_table[l, h]``.

    Args:
        :param add_pages:
            Tensor of shape ``[L, H]`` (int32) indicating how many pages to
            append for each (layer, head).
        :param new_phys_pages:
            1D tensor of shape ``[N]`` (int32) containing physical page IDs
            for all (layer, head) pairs, concatenated in row-major (L, H)
            order. ``N`` must equal ``add_pages.sum()``.
        :param curr_pages:
            Tensor of shape ``[L, H]`` (int32) with the current logical page
            counts per (layer, head) before this update.
        :param page_table:
            Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
            the logical to physical page mapping. The last dimension is
            logically indexed as logical_page ∈ [0, max_pages_per_head).
        :param max_pages_per_head:
            Maximum number of logical pages permitted per (layer, head). The
            kernel skips writes beyond this bound.
    Returns:
        None. The function updates ``page_table`` in-place.
    """
    L, H = add_pages.shape
    if L == 0 or H == 0:
        return
    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)
    # Exclusive prefix sum: entry i is where (l, h) = divmod(i, H) starts
    # reading in ``new_phys_pages``. Allocate it next to the inputs instead of
    # on the default CUDA device, which may be a different GPU.
    cum_page_heads = torch.empty(L * H + 1, device=add_flat.device, dtype=torch.int32)
    cum_page_heads[0] = 0
    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])
    # Pass the real strides: ``page_table`` may be a non-contiguous slice.
    stride_pl, stride_ph, stride_pp = page_table.stride()
    grid = (L, H)
    _scatter_pages_kernel_lh[grid](
        add_flat,
        cum_page_heads,
        new_phys_pages,
        curr_flat,
        page_table,
        stride_pl,
        stride_ph,
        stride_pp,
        L=L,
        H=H,
        max_pages_per_head=max_pages_per_head,
    )
# One program per (layer, head): copies that pair's slice of the flat list of
# new physical page ids into the page table, right after its existing pages.
@triton.jit
def _scatter_pages_kernel_lh(
    add_pages,  # int32 [L*H]
    cum_page_heads,  # int32 [L*H + 1], base offset in flat_new_phys per (l,h)
    flat_new_phys,  # int32 [total_pages]
    curr_pages,  # int32 [L*H], existing logical pages per (l,h)
    page_table_ptr,  # int32* base pointer to page_table
    stride_pl,  # int, stride for layer dim
    stride_ph,  # int, stride for head dim
    stride_pp,  # int, stride for page dim
    L: tl.constexpr,
    H: tl.constexpr,
    max_pages_per_head: tl.constexpr,
):
    layer_idx = tl.program_id(0)
    h = tl.program_id(1)
    if layer_idx >= L or h >= H:
        return
    # row-major flat index of this (layer, head)
    lh = layer_idx * H + h
    ap = tl.load(add_pages + lh)
    if ap <= 0:
        # nothing to append for this pair
        return
    base = tl.load(cum_page_heads + lh)
    cp = tl.load(curr_pages + lh)
    # Append ap pages: logical pages [cp .. cp+ap)
    for i in tl.range(0, ap):
        phys = tl.load(flat_new_phys + base + i)
        lp = cp + i
        # writes past the per-head page limit are dropped
        if lp < max_pages_per_head:
            offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
            tl.store(page_table_ptr + offset, phys)
# TODO: write reclaim kernel
# Placeholder only: the kernel takes no arguments and has an empty body.
@triton.jit
def reclaim_page_kernel():
    pass
def reclaim_pages(
    batch_index: int,
    bh_seq_lens: torch.Tensor,
    bh_num_pages: torch.Tensor,
    page_table: torch.Tensor,
):
    """Placeholder for a kernel-based page reclaim; currently does nothing."""
    return None
import torch
import torch.nn.functional as F
from torch import nn
class SiluAndMul(nn.Module):
    """Gated activation: ``silu(first half) * second half`` along the last dim."""

    def __init__(self):
        super().__init__()

    # @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` in two along the last dimension and gate the halves."""
        gate, up = torch.chunk(x, 2, dim=-1)
        activated = F.silu(gate)
        return activated * up
from typing import Optional
import torch
from compactor_vllm.attention.sparse_decode_kernel import head_sparse_decode_attention
from compactor_vllm.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
from compactor_vllm.compression.common import extract_and_store_top_kv
from compactor_vllm.config.engine_config import AttentionBackend
from compactor_vllm.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
from compactor_vllm.utils.context import Context, get_context
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from torch import nn
class Attention(nn.Module):
    """
    Attention layer backed by a paged KV cache.

    The cache tensors (``k_cache``, ``v_cache``, ``page_table``,
    ``bh_seq_lens``) and ``page_size`` start as ``None`` and are expected to
    be assigned from outside before the cache-dependent paths are used.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads: int = num_heads
        self.head_dim = head_dim
        self.scale: float = scale
        self.num_kv_heads = int(num_kv_heads)
        # Cache state, attached later.
        self.k_cache: Optional[torch.Tensor] = None
        self.v_cache: Optional[torch.Tensor] = None
        self.page_table: Optional[torch.Tensor] = None
        self.bh_seq_lens: Optional[torch.Tensor] = None
        self.page_size: Optional[int] = None

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        scores: Optional[torch.Tensor] = None,
    ):
        """
        Store the new K/V in the cache (when one is attached), run prefill or
        decode attention according to the active context, and write the
        updated per-(batch, head) lengths back into ``bh_seq_lens``.
        """
        context: Context = get_context()
        batch_mapping = context.batch_mapping
        tracks_lengths = self.bh_seq_lens is not None
        if tracks_lengths:
            # Working copy of the lengths for the rows in this batch.
            seq_lens = self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
        else:
            seq_lens = None

        if context.is_prefill:
            o = self._prefill(q, k, v, scores, context, batch_mapping, seq_lens)
        else:
            o = self._decode(q, k, v, context, batch_mapping, seq_lens)

        if tracks_lengths:
            rows = batch_mapping.to(torch.long)
            maybe_execute_in_stream(
                self.bh_seq_lens.index_copy_,
                0,
                rows,
                seq_lens,
                STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
            )
        return o

    def _prefill(self, q, k, v, scores, context, batch_mapping, seq_lens):
        """Store prefill K/V (compressed or in full) and compute prefill attention."""
        # Snapshot the lengths BEFORE the store kernels advance ``seq_lens``.
        seq_lens_copy = None if seq_lens is None else seq_lens.clone()
        has_cache = self.k_cache is not None

        if has_cache and context.do_compression and scores is not None:
            compression_context = context.compression_context
            assert compression_context is not None
            maybe_execute_in_stream(
                extract_and_store_top_kv,
                scores=scores,
                cu_seqlens_k=context.cu_seqlens_k,
                max_k_len=context.max_seqlen_k,
                top_k=compression_context.max_tokens_to_retain,
                H=int(self.num_kv_heads),
                new_keys=k,
                new_vals=v,
                num_tokens_to_retain=compression_context.batch_tokens_to_retain,
                page_table=self.page_table,
                batch_mapping=batch_mapping,
                bh_lens=seq_lens,
                k_cache=self.k_cache,
                v_cache=self.v_cache,
                PAGE_SIZE=self.page_size,
                PAD_TO_PAGE_SIZE=True,
                STORE_STREAM=context.STORE_STREAM,
            )
        elif has_cache:
            maybe_execute_in_stream(
                prefill_store_all_kv,
                new_keys=k,
                new_values=v,
                cu_seqlens_k=context.cu_seqlens_k,
                max_seqlen_k=context.max_seqlen_k,
                k_cache=self.k_cache,
                v_cache=self.v_cache,
                page_table=self.page_table,
                bh_lens=seq_lens,
                batch_mapping=batch_mapping,
                PAGE_SIZE=self.page_size,
                STORE_STREAM=context.STORE_STREAM,
            )

        # No compression: FA varlen on q,k,v (matches HF). Compressed: Triton reads paged KV.
        if context.attention_backend == AttentionBackend.FLASH_ATTENTION:
            use_flash_prefill = True
        else:
            use_flash_prefill = (
                context.attention_backend == AttentionBackend.COMPACTOR_TRITON
                and not context.do_compression
            )

        if use_flash_prefill:
            return flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
            )
        if context.attention_backend != AttentionBackend.COMPACTOR_TRITON:
            raise NotImplementedError

        # Top-k KV writes on STORE_STREAM; Triton prefill must see finished writes.
        if context.do_compression and context.STORE_STREAM is not None:
            torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
        return causal_sparse_varlen_with_cache(
            q,
            k,
            v,
            self.k_cache,
            self.v_cache,
            seq_lens_bh=seq_lens_copy,
            global_page_table=self.page_table,
            batch_mapping=batch_mapping,
            cu_seqlens_q=context.cu_seqlens_q,
            max_seqlen_q=context.max_seqlen_q,
            max_seqlen_k_cache=context.max_bh_len,
            HKV=int(self.num_kv_heads),
            PAGE_SIZE=self.page_size,
            sm_scale=self.scale,
        )

    def _decode(self, q, k, v, context, batch_mapping, seq_lens):
        """Append the new token's K/V to the cache, then attend over the cache."""
        assert self.k_cache is not None, "KV Cache must be initialized for decoding"
        decode_store_kv(
            key=k,
            value=v,
            batch_mapping=batch_mapping,
            bh_lens=seq_lens,
            page_table=self.page_table,
            k_cache=self.k_cache,
            v_cache=self.v_cache,
            PAGE_SIZE=self.page_size,
        )
        return head_sparse_decode_attention(
            q,
            self.k_cache,
            self.v_cache,
            seq_lens,
            self.page_table,
            batch_mapping,
            int(self.num_kv_heads),
            self.page_size,
            self.scale,
            key_split=context.key_split,
        )
import torch
import torch.distributed as dist
import torch.nn.functional as F
from compactor_vllm.utils.context import get_context
from torch import nn
class VocabParallelEmbedding(nn.Module):
    """
    Embedding table sharded along the vocabulary dimension.

    Each rank stores ``num_embeddings // world_size`` consecutive rows. In
    ``forward`` a rank looks up only the ids inside its own range, zeroes the
    rows for all other ids, and the partial results are summed with an
    all-reduce.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        self.tp_rank = dist.get_rank()
        self.tp_size = dist.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        # Half-open id range [start, end) owned by this rank.
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.weight = nn.Parameter(
            torch.empty(self.num_embeddings_per_partition, embedding_dim)
        )
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row shard of the full ``loaded_weight`` into ``param``."""
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        """Embed token ids ``x``; the result has one trailing embedding dim."""
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            # Out-of-range ids are remapped to local row 0 and masked out below.
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            # unsqueeze(-1) equals unsqueeze(1) for 1-D ids and also broadcasts
            # correctly when ``x`` carries extra leading (batch) dimensions.
            y = mask.unsqueeze(-1) * y
            dist.all_reduce(y)
        return y
class ParallelLMHead(VocabParallelEmbedding):
    """
    Output projection that reuses the vocabulary-sharded weight layout.

    With more than one rank, the per-rank logits are gathered on rank 0 and
    concatenated along the last dimension; every other rank returns ``None``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        assert not bias
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, x: torch.Tensor):
        """Project hidden states to logits (last position per sequence in prefill)."""
        context = get_context()
        if context.is_prefill:
            # Keep only the final token of each packed sequence.
            final_positions = context.cu_seqlens_q[1:] - 1
            x = x[final_positions].contiguous()
        logits = F.linear(x, self.weight)
        if self.tp_size <= 1:
            return logits

        is_root = self.tp_rank == 0
        if is_root:
            gathered = [torch.empty_like(logits) for _ in range(self.tp_size)]
        else:
            gathered = None
        dist.gather(logits, gathered, 0)
        if not is_root:
            return None
        return torch.cat(gathered, -1)
import torch
from torch import nn
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
# @torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
# @torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
def divide(numerator, denominator):
    """Return ``numerator // denominator``, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0
    return quotient
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is held by every rank (no sharding)."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size=input_size, output_size=output_size, bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the checkpoint tensor into ``param`` unchanged."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ weight.T (+ bias)``."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class ColumnParallelLinear(LinearBase):
    """
    Linear layer sharded along the output dimension (``tp_dim == 0``).

    Each rank holds ``output_size // world_size`` output rows and produces the
    matching slice of the output; no communication happens in ``forward``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        world_size = dist.get_world_size()
        local_output_size = divide(output_size, world_size)
        super().__init__(input_size, local_output_size, bias, tp_dim=0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along ``tp_dim``) of the full tensor into ``param``."""
        local = param.data
        width = local.size(self.tp_dim)
        offset = self.tp_rank * width
        local.copy_(loaded_weight.narrow(self.tp_dim, offset, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local shard: ``x @ weight.T (+ bias)``."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
    """
    Several column-parallel projections fused into one weight.

    The local weight is the concatenation, along the output dimension, of
    this rank's shard of every sub-projection listed in ``output_sizes``.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias)

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
    ):
        """Load sub-projection ``loaded_shard_id`` into its region of ``param``."""
        preceding_rows = sum(self.output_sizes[:loaded_shard_id])
        local_offset = preceding_rows // self.tp_size
        local_rows = self.output_sizes[loaded_shard_id] // self.tp_size
        destination = param.data.narrow(self.tp_dim, local_offset, local_rows)
        # The checkpoint tensor holds the full sub-projection; take our chunk.
        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
        destination.copy_(rank_chunks[self.tp_rank])
class QKVParallelLinear(ColumnParallelLinear):
    """
    Fused query/key/value projection, sharded by attention head.

    The local weight is laid out as ``[q | k | v]`` along the output
    dimension, with ``num_heads`` query heads and ``num_kv_heads`` key and
    value heads per rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        world_size = dist.get_world_size()
        # A falsy KV head count falls back to the query head count.
        total_num_kv_heads = total_num_kv_heads or total_num_heads
        self.head_size = head_size
        self.num_heads = divide(total_num_heads, world_size)
        self.num_kv_heads = divide(total_num_kv_heads, world_size)
        fused_heads = total_num_heads + 2 * total_num_kv_heads
        super().__init__(hidden_size, fused_heads * self.head_size, bias)

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
    ):
        """Load the ``"q"``, ``"k"`` or ``"v"`` projection into its region of ``param``."""
        assert loaded_shard_id in ["q", "k", "v"]
        q_rows = self.num_heads * self.head_size
        kv_rows = self.num_kv_heads * self.head_size
        # shard id -> (row offset, row count) inside the fused local weight
        regions = {
            "q": (0, q_rows),
            "k": (q_rows, kv_rows),
            "v": (q_rows + kv_rows, kv_rows),
        }
        row_offset, row_count = regions[loaded_shard_id]
        destination = param.data.narrow(self.tp_dim, row_offset, row_count)
        rank_chunks = loaded_weight.chunk(self.tp_size, self.tp_dim)
        destination.copy_(rank_chunks[self.tp_rank])
class RowParallelLinear(LinearBase):
    """
    Linear layer sharded along the input dimension (``tp_dim == 1``).

    Each rank multiplies its slice of the input features with its weight
    shard; the partial outputs are summed with an all-reduce. The bias is not
    sharded and is added on rank 0 only, so it is counted once in the sum.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        tp_size = dist.get_world_size()
        super().__init__(divide(input_size, tp_size), output_size, bias, 1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """
        Load ``loaded_weight`` into ``param``.

        The 2-D weight is sliced along ``tp_dim``. A parameter without that
        dimension (the 1-D bias) is replicated, so it is copied whole; slicing
        it along dim 1 would raise ``IndexError``.
        """
        param_data = param.data
        if param_data.dim() <= self.tp_dim:
            param_data.copy_(loaded_weight)
            return
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the partial product and reduce it across ranks."""
        # Only rank 0 contributes the bias; the all-reduce then adds it once.
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment