Unverified Commit 4fff1ec1 authored by Stefan He, committed by GitHub

Deterministic Mode: Add 1-stage triton kernel for prefill (#11147)


Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Binyao Jiang <bijiang@linkedin.com>
parent 7a020e0f
......@@ -64,13 +64,19 @@ class TritonAttnBackend(AttentionBackend):
decode_attention_fwd,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
build_unified_kv_indices,
extend_attention_fwd,
extend_attention_fwd_unified,
)
super().__init__()
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
self.extend_attention_fwd_unified = torch.compiler.disable(
extend_attention_fwd_unified
)
self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices)
# Parse args
self.skip_prefill = skip_prefill
......@@ -794,6 +800,7 @@ class TritonAttnBackend(AttentionBackend):
else:
o = torch.empty_like(q)
# Save KV cache first (must do this before unified kernel)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
......@@ -805,6 +812,13 @@ class TritonAttnBackend(AttentionBackend):
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False
# Deterministic mode: use unified 1-stage kernel
if self.enable_deterministic:
return self._forward_extend_unified(
q, o, layer, forward_batch, causal, logits_soft_cap, sinks
)
# Normal mode: use original 2-stage kernel
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
sliding_window_size = (
layer.sliding_window_size
......@@ -841,6 +855,127 @@ class TritonAttnBackend(AttentionBackend):
)
return o
def _forward_extend_unified(
self,
q: torch.Tensor,
o: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
causal: bool,
logits_soft_cap: float,
sinks: Optional[torch.Tensor],
):
"""
Unified 1-stage extend attention for deterministic inference.
Both prefix and extend KV are accessed through unified kv_indices.
"""
bs = forward_batch.batch_size
# Determine sliding window settings
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
sliding_window_size = layer.sliding_window_size
# For sliding-window layers, pass the window-restricted prefix kv_indptr/kv_indices to the unified kernel
prefix_kv_indptr = self.forward_metadata.window_kv_indptr
prefix_kv_indices = self.forward_metadata.window_kv_indices
# Compute window start positions (absolute position of first key in window)
# window_start_pos = seq_len - window_len
window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
# Handle TARGET_VERIFY mode where extend_prefix_lens might not be set
if forward_batch.extend_prefix_lens is not None:
window_start_pos = (
forward_batch.extend_prefix_lens[:bs] - window_kv_lens
)
else:
# Infer from spec_info: prefix_len = seq_len - draft_token_num
if forward_batch.spec_info is not None and hasattr(
forward_batch.spec_info, "draft_token_num"
):
extend_prefix_lens = (
forward_batch.seq_lens[:bs]
- forward_batch.spec_info.draft_token_num
)
window_start_pos = extend_prefix_lens - window_kv_lens
else:
window_start_pos = None
else:
sliding_window_size = -1
prefix_kv_indptr = self.forward_metadata.kv_indptr
prefix_kv_indices = self.forward_metadata.kv_indices
window_start_pos = None
# Build unified kv_indices using fused Triton kernel
extend_kv_indices = forward_batch.out_cache_loc
# Handle cases where extend_seq_lens or extend_start_loc might not be set
# In speculative decoding, we can infer these from spec_info or compute them
if forward_batch.extend_seq_lens is None:
# TARGET_VERIFY mode: infer extend_seq_lens from spec_info
if forward_batch.spec_info is not None and hasattr(
forward_batch.spec_info, "draft_token_num"
):
draft_token_num = forward_batch.spec_info.draft_token_num
extend_seq_lens = torch.full(
(bs,), draft_token_num, dtype=torch.int32, device=self.device
)
else:
raise RuntimeError(
"extend_seq_lens is None but cannot infer from spec_info. "
"This should not happen in TARGET_VERIFY mode."
)
else:
extend_seq_lens = forward_batch.extend_seq_lens
# Check extend_start_loc separately - it might be None even when extend_seq_lens is set
if forward_batch.extend_start_loc is None:
# Compute extend_start_loc from extend_seq_lens
# extend_start_loc[i] = sum(extend_seq_lens[0:i])
extend_start_loc = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=self.device),
torch.cumsum(extend_seq_lens[:-1], dim=0),
]
)
else:
extend_start_loc = forward_batch.extend_start_loc
unified_kv_indptr, unified_kv_indices, prefix_lens = (
self.build_unified_kv_indices(
prefix_kv_indptr,
prefix_kv_indices,
extend_start_loc,
extend_seq_lens,
extend_kv_indices,
bs,
)
)
# Convert prefix_lens to int32 for the kernel
prefix_lens = prefix_lens.to(torch.int32)
# Call unified kernel
self.extend_attention_fwd_unified(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
self.forward_metadata.qo_indptr,
unified_kv_indptr,
unified_kv_indices,
prefix_lens,
self.forward_metadata.max_extend_len,
custom_mask=self.forward_metadata.custom_mask,
mask_indptr=self.forward_metadata.mask_indptr,
sm_scale=layer.scaling,
logit_cap=logits_soft_cap,
is_causal=causal,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_start_pos=window_start_pos,
xai_temperature_len=layer.xai_temperature_len,
)
return o
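To make the sliding-window bookkeeping above concrete, here is a small numeric sketch (hypothetical values, not taken from a real batch) of how window_start_pos falls out of the window-restricted prefix indptr and the full prefix lengths:
import torch

# Hypothetical batch of 2 sequences with a 64-key sliding window (illustration only).
# Sequence 0: 100 prefix tokens, the window keeps only the last 64 of them.
# Sequence 1: 30 prefix tokens, shorter than the window, so all of them are kept.
window_kv_indptr = torch.tensor([0, 64, 94], dtype=torch.int32)
extend_prefix_lens = torch.tensor([100, 30], dtype=torch.int32)

bs = 2
window_kv_lens = window_kv_indptr[1 : bs + 1] - window_kv_indptr[:bs]  # -> [64, 30]
# Absolute position of the first key that survives in each window:
window_start_pos = extend_prefix_lens[:bs] - window_kv_lens            # -> [36, 0]
The kernel adds these offsets back so the sliding-window mask compares true absolute positions rather than positions inside the already-truncated window.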
def forward_decode(
self,
q: torch.Tensor,
......
......@@ -32,12 +32,182 @@ if _is_cuda:
_is_hip = is_hip()
def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
"""
Get block sizes and configuration for extend attention kernels.
Args:
Lq: Query head dimension
Lv: Value head dimension
Returns:
tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
"""
# Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension
if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lq == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
elif Lq == 192:
BLOCK_DMODEL = 128
BLOCK_DPE = 64
else:
BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
# Determine BLOCK_M, BLOCK_N, and num_warps based on hardware
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
num_warps = 4
else:
if _is_cuda and CUDA_CAPABILITY[0] >= 9:
# Hopper architecture (H100, etc.)
if Lq <= 256:
BLOCK_M, BLOCK_N = (128, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
# Ampere / Ada architectures (CC 8.x: A100, sm86/sm89, etc.)
# sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
if Lq <= 128:
BLOCK_M, BLOCK_N = (64, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 32)
else:
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
else:
# Older architectures
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
num_warps = 4 if Lq <= 64 else 8
return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps
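As a quick sanity check of the branch logic, here is a hedged trace for a common head dimension, assuming an sm90 (Hopper) GPU so that _is_cuda is true and CUDA_CAPABILITY[0] >= 9; the BLOCK_M / BLOCK_N / num_warps values in the comments hold only under that assumption:
# Hypothetical trace for Lq = Lv = 128 on an assumed Hopper GPU:
#   Lq not in (576, 288, 192) -> BLOCK_DMODEL = next_power_of_2(128) = 128, BLOCK_DPE = 0
#   BLOCK_DV = next_power_of_2(128) = 128
#   CC major >= 9, Lq <= 256  -> BLOCK_M, BLOCK_N = 128, 64
#   Lq > 64                   -> num_warps = 8
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
    _get_block_sizes_for_extend_attention(128, 128)
)
assert (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV) == (128, 0, 128)  # hardware-independent part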
@triton.jit
def tanh(x):
# tanh(x) = 2 * sigmoid(2 * x) - 1, so reuse tl.sigmoid instead of a separate tanh op
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _copy_unified_indices_kernel(
# Input buffers
prefix_kv_indptr,
prefix_kv_indices,
extend_start_loc,
extend_seq_lens,
extend_kv_indices,
unified_kv_indptr,
# Output buffer
unified_kv_indices,
# Size
bs,
):
"""
Triton kernel that copies prefix and extend indices into the unified buffer.
Each program instance handles one sequence, using vectorized loads/stores.
"""
pid = tl.program_id(0)
if pid >= bs:
return
# Load sequence info
prefix_start = tl.load(prefix_kv_indptr + pid)
prefix_end = tl.load(prefix_kv_indptr + pid + 1)
extend_start = tl.load(extend_start_loc + pid)
extend_len = tl.load(extend_seq_lens + pid)
prefix_len = prefix_end - prefix_start
unified_start = tl.load(unified_kv_indptr + pid)
# Copy indices in vectorized chunks
BLOCK_SIZE: tl.constexpr = 128
# Process prefix indices
for block_start in range(0, prefix_len, BLOCK_SIZE):
offs = block_start + tl.arange(0, BLOCK_SIZE)
mask = offs < prefix_len
src_idx = prefix_start + offs
dst_idx = unified_start + offs
vals = tl.load(prefix_kv_indices + src_idx, mask=mask, other=0)
tl.store(unified_kv_indices + dst_idx, vals, mask=mask)
# Process extend indices
for block_start in range(0, extend_len, BLOCK_SIZE):
offs = block_start + tl.arange(0, BLOCK_SIZE)
mask = offs < extend_len
src_idx = extend_start + offs
dst_idx = unified_start + prefix_len + offs
vals = tl.load(extend_kv_indices + src_idx, mask=mask, other=0)
tl.store(unified_kv_indices + dst_idx, vals, mask=mask)
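For reference, a plain-PyTorch sketch of what the copy kernel computes, one sequence at a time (handy for debugging or unit tests; the production path uses the Triton kernel above):
def copy_unified_indices_reference(
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    unified_kv_indptr,
    unified_kv_indices,
    bs,
):
    # Mirrors one Triton program per sequence: the sequence's prefix indices are
    # written first, followed by its extend indices, starting at unified_start.
    for i in range(bs):
        p0, p1 = int(prefix_kv_indptr[i]), int(prefix_kv_indptr[i + 1])
        e0, e_len = int(extend_start_loc[i]), int(extend_seq_lens[i])
        u0 = int(unified_kv_indptr[i])
        prefix_len = p1 - p0
        unified_kv_indices[u0 : u0 + prefix_len] = prefix_kv_indices[p0:p1]
        unified_kv_indices[u0 + prefix_len : u0 + prefix_len + e_len] = (
            extend_kv_indices[e0 : e0 + e_len]
        )
    return unified_kv_indices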
def build_unified_kv_indices(
prefix_kv_indptr: torch.Tensor,
prefix_kv_indices: torch.Tensor,
extend_start_loc: torch.Tensor,
extend_seq_lens: torch.Tensor,
extend_kv_indices: torch.Tensor,
bs: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Build unified KV indices efficiently:
- Use PyTorch's optimized cumsum (NVIDIA CUB) for indptr
- Use Triton kernel for parallel index copying
Returns:
(unified_kv_indptr, unified_kv_indices, prefix_lens)
"""
device = prefix_kv_indptr.device
prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
# Create unified_kv_indptr avoiding direct assignment (for CUDA graph compatibility)
unified_lens = prefix_lens + extend_seq_lens[:bs]
unified_kv_indptr = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=device),
torch.cumsum(unified_lens, dim=0),
]
)
max_unified_len = len(prefix_kv_indices) + len(extend_kv_indices)
unified_kv_indices = torch.empty(max_unified_len, dtype=torch.int64, device=device)
# Launch Triton kernel for parallel index copying
_copy_unified_indices_kernel[(bs,)](
prefix_kv_indptr,
prefix_kv_indices,
extend_start_loc,
extend_seq_lens,
extend_kv_indices,
unified_kv_indptr,
unified_kv_indices,
bs,
)
return unified_kv_indptr, unified_kv_indices, prefix_lens
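A small usage example with toy slot indices (illustration only; the unit test below covers larger shapes). Per sequence, the prefix slots come first, immediately followed by that sequence's out_cache_loc slots:
import torch

device = "cuda"
# Two sequences: prefix lengths 2 and 3, extend lengths 1 and 2 (toy slot numbers).
prefix_kv_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device=device)
prefix_kv_indices = torch.tensor([10, 11, 20, 21, 22], dtype=torch.int64, device=device)
extend_start_loc = torch.tensor([0, 1], dtype=torch.int32, device=device)
extend_seq_lens = torch.tensor([1, 2], dtype=torch.int32, device=device)
extend_kv_indices = torch.tensor([100, 200, 201], dtype=torch.int64, device=device)

indptr, indices, prefix_lens = build_unified_kv_indices(
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    bs=2,
)
# indptr      -> [0, 3, 8]
# indices     -> [10, 11, 100, 20, 21, 22, 200, 201]
# prefix_lens -> [2, 3]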
@triton.jit
def _fwd_kernel(
Q_Extend,
......@@ -402,50 +572,10 @@ def extend_attention_fwd(
v_extend.shape[-1],
)
if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lq == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
elif Lq == 192:
BLOCK_DMODEL = 128
BLOCK_DPE = 64
else:
BLOCK_DMODEL = triton.next_power_of_2(Lq)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
num_warps = 4
else:
if _is_cuda and CUDA_CAPABILITY[0] >= 9:
if Lq <= 256:
BLOCK_M, BLOCK_N = (128, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
# sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
if Lq <= 128:
BLOCK_M, BLOCK_N = (64, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 32)
else:
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 128)
elif Lq <= 256:
BLOCK_M, BLOCK_N = (64, 64)
else:
BLOCK_M, BLOCK_N = (32, 64)
else:
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
num_warps = 4 if Lk <= 64 else 8
# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)
sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
......@@ -548,3 +678,368 @@ def redundant_attention(
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
pt += cur_seq_len_extend
@triton.jit
def _fwd_kernel_unified(
Q,
O,
K_Buffer,
V_Buffer,
qo_indptr,
kv_indptr,
kv_indices,
prefix_lens,
mask_ptr,
mask_indptr,
sink_ptr,
window_start_pos,
sm_scale,
kv_group_num,
stride_qbs,
stride_qh,
stride_obs,
stride_oh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
SLIDING_WINDOW_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
HAS_SINK: tl.constexpr,
):
"""
Unified 1-stage kernel for deterministic extend attention.
Both prefix and extend KV are accessed through the unified kv_indices.
"""
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
cur_block_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
# Load sequence information
cur_seq_q_start_idx = tl.load(qo_indptr + cur_seq)
cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start_idx
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
cur_seq_prefix_len = tl.load(prefix_lens + cur_seq)
# Load window start position for sliding window attention
# This is the absolute position of the first key in the window (0 if no sliding window)
cur_window_start = 0
if SLIDING_WINDOW_SIZE > 0:
cur_window_start = tl.load(window_start_pos + cur_seq)
# Load custom mask start index if using custom mask (for speculative decoding)
if USE_CUSTOM_MASK:
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv
# XAI temperature handling
if xai_temperature_len > 0:
offs_qidx = cur_seq_prefix_len + cur_block_m * BLOCK_M + offs_m
xai_temperature_reg = tl.where(
offs_qidx < xai_temperature_len,
1.0,
xai_temperature_len / (offs_qidx + 1.0),
)
# Load Q
offs_q = (
(cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_qpe = (
(cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_dpe[None, :]
)
qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0)
# Initialize accumulators
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
# Unified loop: process all KV tokens (prefix + extend)
for start_n in range(0, cur_seq_kv_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_kv_len
# Compute mask
final_mask = mask_m[:, None] & mask_n[None, :]
# Apply custom mask if provided
if USE_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_kv_len
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
final_mask &= custom_mask
# Apply causal mask for extend part
if IS_CAUSAL and not USE_CUSTOM_MASK:
# Determine if current KV block is in extend region
# Only apply causal mask when both Q and K are in extend region
q_idx = cur_block_m * BLOCK_M + offs_m[:, None]
k_idx_in_total = start_n + offs_n[None, :]
# Causal mask: q_idx >= (k_idx - prefix_len) when k_idx >= prefix_len
# For prefix region (k_idx < prefix_len), no causal mask
k_is_extend = k_idx_in_total >= cur_seq_prefix_len
k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len
causal_mask = tl.where(
k_is_extend,
q_idx >= k_idx_in_extend,
True, # No causal mask for prefix
)
final_mask &= causal_mask
if SLIDING_WINDOW_SIZE > 0:
# Sliding window mask with correct absolute positions
# Q absolute position: window_start + prefix_len + q_position_in_extend
q_abs_pos = (
cur_window_start
+ cur_seq_prefix_len
+ cur_block_m * BLOCK_M
+ offs_m[:, None]
)
# K absolute position: window_start + k_index_in_unified_array
k_abs_pos = cur_window_start + start_n + offs_n[None, :]
# Sliding window: query can attend to keys within window_size
window_mask = q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
# Check if we can skip this tile
SKIP_TILE = False
if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
if not SKIP_TILE:
# Load KV indices
offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
mask=mask_n,
other=0,
)
# Load K
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(mask_n[None, :]) & (mask_d[:, None]),
other=0.0,
)
# Compute QK
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(final_mask, qk, float("-inf"))
# Online softmax
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
# Load V
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=mask_n[:, None] & mask_dv[None, :],
other=0.0,
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
# Handle sink tokens
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
deno += tl.exp(cur_sink - e_max)
# Store output
offs_o = (
(cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_dv[None, :]
)
tl.store(
O + offs_o,
acc / deno[:, None],
mask=mask_m[:, None] & mask_dv[None, :],
)
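The inner loop is a standard online-softmax accumulation. As a reference (not used at runtime), the per-tile update in plain PyTorch, using the same variable names as the kernel:
import torch

def online_softmax_tile_update(acc, deno, e_max, qk, v):
    # acc:   [M, Dv] running weighted sum of values
    # deno:  [M]     running softmax denominator
    # e_max: [M]     running row-max of the logits
    # qk:    [M, N]  masked, scaled logits for this tile (-inf where masked out)
    # v:     [N, Dv] value tile
    row_max = qk.max(dim=1).values
    row_max_fixed = torch.where(
        row_max == float("-inf"), torch.full_like(row_max, -1e20), row_max
    )
    n_e_max = torch.maximum(row_max_fixed, e_max)
    re_scale = torch.exp(e_max - n_e_max)
    p = torch.exp(qk - n_e_max[:, None])
    deno = deno * re_scale + p.sum(dim=1)
    acc = acc * re_scale[:, None] + p @ v
    return acc, deno, n_e_max
Because each sequence walks its unified KV list in one fixed tile order, the floating-point summation order is fixed as well, which is the property the deterministic mode relies on.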
def extend_attention_fwd_unified(
q,
o,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
prefix_lens,
max_len_extend,
custom_mask=None,
mask_indptr=None,
sm_scale=None,
logit_cap=0.0,
is_causal=True,
sliding_window_size=-1,
sinks=None,
window_start_pos=None,
xai_temperature_len=-1,
):
"""
Unified 1-stage extend attention for deterministic inference.
Args:
q: Query tensor [num_tokens, num_heads, head_dim]
o: Output tensor [num_tokens, num_heads, head_dim]
k_buffer: Key cache buffer
v_buffer: Value cache buffer
qo_indptr: Query offsets [batch_size + 1]
kv_indptr: KV offsets [batch_size + 1] (includes both prefix and extend)
kv_indices: Unified KV indices (both prefix and extend)
prefix_lens: Prefix length for each sequence [batch_size]
max_len_extend: Maximum extend length
custom_mask: Custom attention mask (for speculative decoding tree attention)
mask_indptr: Mask offsets [batch_size + 1]
sm_scale: Softmax scale
logit_cap: Logit capping value
is_causal: Whether to apply causal mask
sliding_window_size: Sliding window size (-1 for no sliding window)
sinks: Per-head attention sink logits (optional)
window_start_pos: Absolute position of first key in sliding window [batch_size]
(None if sliding window not used)
xai_temperature_len: XAI temperature length
"""
Lq, Lv = q.shape[-1], v_buffer.shape[-1]
# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)
sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]
USE_CUSTOM_MASK = custom_mask is not None
HAS_SINK = sinks is not None
# For sliding window attention, window_start_pos tracks the absolute position
# of the first key in each sequence's window
if sliding_window_size > 0 and window_start_pos is None:
# If not provided, assume window starts at position 0
window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
extra_kargs = {}
if _is_hip:
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
_fwd_kernel_unified[grid](
q,
o,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
prefix_lens,
custom_mask,
mask_indptr,
sinks,
window_start_pos,
sm_scale,
kv_group_num,
q.stride(0),
q.stride(1),
o.stride(0),
o.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
Lq=Lq,
Lv=Lv,
IS_CAUSAL=is_causal,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
HAS_SINK=HAS_SINK,
num_warps=num_warps,
num_stages=num_stages,
**extra_kargs,
)
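A minimal, hedged call sketch on toy tensors (one sequence, tiny head count, CUDA device assumed; the accompanying unit test exercises realistic configurations):
import torch

device, dtype = "cuda", torch.bfloat16
# One sequence: 2 prefix tokens + 2 extend tokens in a 4-slot toy KV pool,
# 2 query heads sharing 1 KV head, head dim 64.
q = torch.randn(2, 2, 64, dtype=dtype, device=device)         # [extend_tokens, H_Q, D]
o = torch.empty(2, 2, 64, dtype=dtype, device=device)
k_buffer = torch.randn(4, 1, 64, dtype=dtype, device=device)  # [pool_slots, H_KV, D]
v_buffer = torch.randn(4, 1, 64, dtype=dtype, device=device)

qo_indptr = torch.tensor([0, 2], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, 4], dtype=torch.int32, device=device)         # prefix + extend
kv_indices = torch.tensor([0, 1, 2, 3], dtype=torch.int64, device=device)  # pool slots
prefix_lens = torch.tensor([2], dtype=torch.int32, device=device)

extend_attention_fwd_unified(
    q, o, k_buffer, v_buffer,
    qo_indptr, kv_indptr, kv_indices, prefix_lens,
    max_len_extend=2,
)
# o now holds attention over all 4 KV slots, causal only within the 2 extend tokens.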
......@@ -1431,8 +1431,8 @@ class ServerArgs:
f"but you explicitly specified '{self.attention_backend}'."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
# Currently, only FA3 and Triton support radix cache. Support for other backends is in progress
if self.attention_backend not in ["fa3", "triton"]:
self.disable_radix_cache = True
logger.warning(
f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
......
......@@ -424,4 +424,7 @@ if __name__ == "__main__":
BenchArgs.add_cli_args(parser)
args = parser.parse_args()
if args.sampling_seed is None:
args.sampling_seed = 42
test_deterministic(args)
......@@ -10,7 +10,9 @@ from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd_normal,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
build_unified_kv_indices,
extend_attention_fwd,
extend_attention_fwd_unified,
redundant_attention,
)
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
......@@ -571,6 +573,204 @@ class TestTritonAttention(CustomTestCase):
for B, H_Q, H_KV, D, D_V in configs:
self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
def _test_extend_attention_unified_vs_regular_once(self, B, N_CTX, H_Q, H_KV, D):
"""Test that unified kernel produces same results as 2-stage kernel."""
dtype = torch.bfloat16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
# Setup prefix KV indices
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(b_seq_len_prefix.sum().item(),), dtype=torch.int64, device="cuda"
)
for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
# Setup for extend attention
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
# Run 2-stage kernel
o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_regular,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=max_len_extend,
)
# Build unified KV indices
extend_kv_indices = torch.arange(
total_token_num - extend_token_num,
total_token_num,
dtype=torch.int64,
device="cuda",
)
extend_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
extend_start_loc[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
unified_kv_indptr, unified_kv_indices, prefix_lens = build_unified_kv_indices(
kv_indptr,
kv_indices,
extend_start_loc,
b_seq_len_extend,
extend_kv_indices,
B,
)
# Run unified kernel
o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
extend_attention_fwd_unified(
q_extend,
o_unified,
k_buffer,
v_buffer,
qo_indptr,
unified_kv_indptr,
unified_kv_indices,
prefix_lens,
max_len_extend=max_len_extend,
custom_mask=None,
mask_indptr=None,
sm_scale=None,
logit_cap=0.0,
is_causal=True,
)
# Compare results
self.assertTrue(
torch.allclose(o_regular, o_unified, rtol=0.15, atol=0.15),
f"Unified kernel output differs from 2-stage kernel. "
f"Max diff: {(o_regular - o_unified).abs().max()}",
)
def test_extend_attention_unified_vs_regular(self):
"""Test unified kernel matches 2-stage kernel across different configs."""
configs = [
(4, 512, 32, 8, 128), # Standard config
(2, 2048, 32, 8, 128), # Long sequence (test 2048 specifically)
(8, 256, 64, 8, 80), # Non-standard head dim
]
for B, N_CTX, H_Q, H_KV, D in configs:
with self.subTest(B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D):
self._test_extend_attention_unified_vs_regular_once(
B, N_CTX, H_Q, H_KV, D
)
def test_build_unified_kv_indices(self):
"""Test build_unified_kv_indices correctness."""
B = 4
dtype = torch.int64
device = "cuda"
# Setup test data
prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device)
extend_lens = torch.tensor([5, 3, 7, 4], dtype=torch.int32, device=device)
# Build prefix indices
prefix_kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
prefix_kv_indices = torch.arange(
prefix_lens.sum().item(), dtype=dtype, device=device
)
# Build extend indices
extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0)
extend_kv_indices = torch.arange(
prefix_lens.sum().item(),
prefix_lens.sum().item() + extend_lens.sum().item(),
dtype=dtype,
device=device,
)
# Build unified indices
unified_kv_indptr, unified_kv_indices, returned_prefix_lens = (
build_unified_kv_indices(
prefix_kv_indptr,
prefix_kv_indices,
extend_start_loc,
extend_lens,
extend_kv_indices,
B,
)
)
# Verify unified_kv_indptr
expected_lens = prefix_lens + extend_lens
expected_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
expected_indptr[1:] = torch.cumsum(expected_lens, dim=0)
self.assertTrue(torch.equal(unified_kv_indptr, expected_indptr))
# Verify prefix_lens
self.assertTrue(torch.equal(returned_prefix_lens, prefix_lens))
# Verify unified_kv_indices structure
for i in range(B):
start_idx = int(unified_kv_indptr[i])
end_idx = int(unified_kv_indptr[i + 1])
prefix_len = int(prefix_lens[i])
extend_len = int(extend_lens[i])
# Check that prefix and extend are concatenated correctly
unified_seq = unified_kv_indices[start_idx:end_idx]
self.assertEqual(len(unified_seq), prefix_len + extend_len)
if __name__ == "__main__":
unittest.main()