from __future__ import annotations

"""
End-to-end attention solution with aiter kernels.
"""

import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

try:
    from aiter import (
        flash_attn_varlen_func,
        mha_batch_prefill_func,
        paged_attention_ragged,
    )
    from aiter.mla import mla_decode_fwd
except ImportError:
    print(
        "aiter is an AMD-specific kernel library. Please make sure aiter is installed on your AMD device."
    )

from sglang.srt.configs.model_config import AttentionArch


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    qo_indptr: torch.Tensor
    kv_last_page_len: torch.Tensor
    max_extend_len: int
    max_prefix_extend_len: int
    max_q_len: int
    max_kv_len: int


global_workspace_buffer = None

_AITER_PARTITION_SIZE_ROCM = 256


class AiterAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        self.device = model_runner.device
        self.is_multimodal = model_runner.model_config.is_multimodal
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.kv_cache_dtype = model_runner.kv_cache_dtype
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill

        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        # Create prefill indices updater
        if not skip_prefill:
            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                model_runner, self
            )
            if self.use_mla:
                self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
                    model_runner, self
                )

        # aiter kernel related initialization
        self.max_num_partitions = (
            self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
        ) // _AITER_PARTITION_SIZE_ROCM

        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8
        if not self.use_mla:
            self.workspace_buffer = torch.empty(
                (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
                * nbytes_per_qo_elem
                + 2 * (max_bs * self.num_head * self.max_num_partitions)
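                # The term above sizes the float32 partial-output scratch per
                # (sequence, head, partition); the term below adds two 4-byte
                # reduction values per (sequence, head, partition). This mirrors
                # the usual split-KV paged-attention workspace layout (an
                # assumption about the aiter kernel internals, not stated here).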
                * 4,
                dtype=torch.uint8,
                device=self.device,
            )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.logits_soft_cap = 0.0

        self.forward_metadata: ForwardMetadata = None

        if self.use_mla:
            self.qo_indptr_ = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for the aiter attention backend."""
        bs = forward_batch.batch_size
        kv_indptr = self.kv_indptr
        spec_info = forward_batch.spec_info
        qo_indptr = None
        kv_last_page_len = None
        max_extend_len = None

        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            if self.use_mla:
                qo_indptr = self.qo_indptr_[: bs + 1]
                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
                kv_last_page_len = self.kv_last_page_len[:bs]
                max_extend_len = 1

            self.forward_metadata = ForwardMetadata(
                kv_indptr,
                kv_indices,
                qo_indptr,
                kv_last_page_len,
                max_extend_len,
                None,
                None,
                None,
            )
        elif forward_batch.forward_mode.is_draft_extend():
            if self.use_mla:
                prefix_lens = forward_batch.extend_prefix_lens
                self.mla_indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    prefix_lens,
                    prefix_lens.sum().item(),
                    forward_batch.extend_seq_lens,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=None,
                )
                self.forward_metadata = ForwardMetadata(
                    self.mla_indices_updater_prefill.kv_indptr,
                    self.mla_indices_updater_prefill.kv_indices,
                    self.mla_indices_updater_prefill.qo_indptr,
                    self.mla_indices_updater_prefill.kv_last_page_len,
                    self.mla_indices_updater_prefill.max_extend_len,
                    self.mla_indices_updater_prefill.max_prefix_extend_len,
                    None,
                    None,
                )
            else:
                self.indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    forward_batch.seq_lens_sum,
                    prefix_lens=None,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=forward_batch.spec_info,
                )
                self.forward_metadata = ForwardMetadata(
                    self.indices_updater_prefill.kv_indptr,
                    self.indices_updater_prefill.kv_indices,
                    None,
                    None,
                    None,
                    None,
                    self.indices_updater_prefill.max_q_len,
                    self.indices_updater_prefill.max_kv_len,
                )
        elif forward_batch.forward_mode.is_target_verify():
            if self.use_mla:
                prefix_lens = forward_batch.extend_prefix_lens
                self.mla_indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    prefix_lens,
                    prefix_lens.sum().item(),
                    forward_batch.extend_seq_lens,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=None,
                )
                self.forward_metadata = ForwardMetadata(
                    self.mla_indices_updater_prefill.kv_indptr,
                    self.mla_indices_updater_prefill.kv_indices,
                    self.mla_indices_updater_prefill.qo_indptr,
                    self.mla_indices_updater_prefill.kv_last_page_len,
                    self.mla_indices_updater_prefill.max_extend_len,
                    self.mla_indices_updater_prefill.max_prefix_extend_len,
                    None,
                    None,
                )
            else:
                self.indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    forward_batch.seq_lens_sum,
                    prefix_lens=None,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=forward_batch.spec_info,
                )
                self.forward_metadata = ForwardMetadata(
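                    # Positional fields: kv_indptr, kv_indices, qo_indptr,
                    # kv_last_page_len, max_extend_len, max_prefix_extend_len,
                    # max_q_len, max_kv_len.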
                    self.indices_updater_prefill.kv_indptr,
                    self.indices_updater_prefill.kv_indices,
                    None,
                    None,
                    None,
                    None,
                    self.indices_updater_prefill.max_q_len,
                    self.indices_updater_prefill.max_kv_len,
                )
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            if self.is_multimodal:
                extend_no_prefix = False
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

            if self.use_mla:
                self.mla_indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    prefix_lens,
                    prefix_lens.sum().item(),
                    forward_batch.extend_seq_lens,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=None,
                )
                self.forward_metadata = ForwardMetadata(
                    self.mla_indices_updater_prefill.kv_indptr,
                    self.mla_indices_updater_prefill.kv_indices,
                    self.mla_indices_updater_prefill.qo_indptr,
                    self.mla_indices_updater_prefill.kv_last_page_len,
                    self.mla_indices_updater_prefill.max_extend_len,
                    self.mla_indices_updater_prefill.max_prefix_extend_len,
                    None,
                    None,
                )
            else:
                self.indices_updater_prefill.update(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    forward_batch.seq_lens_sum,
                    prefix_lens,
                    encoder_lens=forward_batch.encoder_lens,
                    spec_info=None,
                )
                self.forward_metadata = ForwardMetadata(
                    self.indices_updater_prefill.kv_indptr,
                    self.indices_updater_prefill.kv_indices,
                    None,
                    None,
                    None,
                    None,
                    self.indices_updater_prefill.max_q_len,
                    self.indices_updater_prefill.max_kv_len,
                )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            qo_indptr = None
            kv_last_page_len = None
            max_extend_len = None
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            if self.use_mla:
                qo_indptr = self.qo_indptr_[: bs + 1]
                qo_indptr[1 : bs + 1] = torch.cumsum(
                    self.cuda_graph_kv_last_page_len[:bs], dim=0
                )
                max_extend_len = 1
                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]

            self.forward_metadata = ForwardMetadata(
                kv_indptr,
                kv_indices,
                qo_indptr,
                kv_last_page_len,
                max_extend_len,
                None,
                None,
                None,
            )
        elif forward_mode.is_target_verify():
            if self.use_mla:
                qo_indptr = self.qo_indptr[: bs + 1]
                qo_indptr[: bs + 1] = torch.arange(
                    0,
                    (1 + bs) * self.num_draft_tokens,
                    step=self.num_draft_tokens,
                    dtype=torch.int32,
                    device=self.device,
                )
                kv_indptr = self.kv_indptr[: bs + 1]
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
                kv_indices = self.cuda_graph_kv_indices
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
                max_extend_len = self.num_draft_tokens
                kv_last_page_len = None
                self.forward_metadata = ForwardMetadata(
                    kv_indptr,
                    kv_indices,
                    qo_indptr,
                    kv_last_page_len,
                    max_extend_len,
                    None,
                    None,
                    None,
                )
            else:
                seq_lens_sum = seq_lens.sum().item()
                self.indices_updater_prefill.update(
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    prefix_lens=None,
                    encoder_lens=encoder_lens,
                    spec_info=spec_info,
                )
                self.forward_metadata = ForwardMetadata(
                    self.indices_updater_prefill.kv_indptr,
                    self.indices_updater_prefill.kv_indices,
                    None,
                    None,
                    None,
                    None,
                    self.indices_updater_prefill.max_q_len,
                    self.indices_updater_prefill.max_kv_len,
                )
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            kv_indptr = self.kv_indptr
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                kv_indptr = kv_indptr[: bs + 1]
                create_flashinfer_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices[:bs],
                    seq_lens[:bs],
                    kv_indptr,
                    None,
                    kv_indices,
                    self.req_to_token.stride(0),
                )
            else:
                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                spec_info=spec_info,
            )
        else:
            raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        self.logits_soft_cap = layer.logit_cap

        if k is not None:
            assert v is not None
            if save_kv_cache:
                if self.use_mla:
                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
                else:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )

        if self.use_mla:
            max_extend_len = self.forward_metadata.max_extend_len
            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
            kv_indptr = self.forward_metadata.kv_indptr
            kv_indices = self.forward_metadata.kv_indices
            kv_last_page_lens = self.forward_metadata.kv_last_page_len
            qo_indptr = self.forward_metadata.qo_indptr
            K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
            kv_lora_rank = V_Buffer.shape[-1]
            qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
            qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim

            assert len(q.shape) == 3
            assert len(k.shape) == 3
            assert len(v.shape) == 3

            if kv_indices.shape[0] == 0:
                o = flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    qo_indptr,
                    qo_indptr,
                    max_extend_len,
                    max_extend_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                )
                return o
            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
                kvc, k_pe = torch.split(
                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
                )
                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
                kvprefix = kvprefix.view(
                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
                )
                k_prefix, v_prefix = torch.split(
                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
                )
                k_prefix = torch.cat(
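                    # Rebuild full per-head prefix keys: the "nope" half comes from
                    # kv_b_proj above, while the shared rope half (k_pe) is broadcast
                    # across all local heads before concatenation.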
                    [
                        k_prefix,
                        torch.broadcast_to(
                            k_pe,
                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
                        ),
                    ],
                    dim=-1,
                )
                assert (
                    forward_batch.extend_prefix_lens.shape
                    == forward_batch.extend_seq_lens.shape
                )
                # Interleave the cached prefix K/V with the newly extended K/V,
                # request by request.
                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])

                o = flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    qo_indptr,
                    kv_indptr,
                    max_extend_len,
                    max_prefix_extend_len,
                    softmax_scale=layer.scaling,
                    causal=True,
                )
                return o
        else:
            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            bs0 = forward_batch.batch_size + 1
            o = mha_batch_prefill_func(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k_cache,
                v_cache,
                self.qo_indptr[:bs0],
                self.forward_metadata.kv_indptr[:bs0],
                self.forward_metadata.kv_indices,
                self.forward_metadata.max_q_len,
                self.forward_metadata.max_kv_len,
                causal=True,
                logits_soft_cap=self.logits_soft_cap,
                alibi_slopes=None,
                return_lse=False,
                return_attn_probs=False,
            )
            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

        if self.use_mla:
            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            mla_decode_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                k_buffer.view(-1, 1, 1, layer.qk_head_dim),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                self.forward_metadata.qo_indptr,
                self.forward_metadata.kv_indptr,
                self.forward_metadata.kv_indices,
                self.forward_metadata.kv_last_page_len,
                self.forward_metadata.max_extend_len,
                layer.scaling,
                layer.logit_cap,
            )
            k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
        else:
            self.logits_soft_cap = layer.logit_cap
            paged_attention_ragged(
                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                self.workspace_buffer,
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
                    -1, 1, layer.tp_k_head_num, layer.qk_head_dim
                ),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
                    -1, 1, layer.tp_v_head_num, layer.v_head_dim
                ),
                self.scale,
                self.forward_metadata.kv_indptr,
                self.forward_metadata.kv_indices,
                self.kv_last_page_len,
                1,
                self.max_num_partitions,
                None,
                "auto",
                "NHD",
                self.logits_soft_cap,
                self.k_scale,
                self.v_scale,
                None,
                _AITER_PARTITION_SIZE_ROCM,
            )

        return o


class AiterIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.update = self.update_single_wrapper

        # get the last index of the pool
        self.pool_size = (
            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
        ) - 1

        self.kv_indices = None
        self.max_q_len = 0
        self.max_kv_len = 0

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be reassigned at runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        kv_start_idx = None
        kv_indptr = self.kv_indptr
        qo_indptr = self.qo_indptr
        paged_kernel_lens = seq_lens
        paged_kernel_lens_sum = seq_lens_sum

        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            # (TODO: Kk) Workaround for CI test_moe_eval_accuracy_large.py:
            # mha_batch_prefill reads 128 entries at a time for its computation.
            # If the real data is not long enough, the default padding value 0
            # would be read, but location 0 is set to NaN during CUDA graph
            # capture, which makes the output tensor NaN. The workaround is to
            # point the padding at the last index of the pool, which is
            # guaranteed not to change.
            kv_indices = torch.full(
                (paged_kernel_lens_sum + 128,),
                self.pool_size,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )

            self.max_kv_len = torch.max(paged_kernel_lens).item()

            extend_lens = seq_lens - prefix_lens
            self.max_q_len = torch.max(extend_lens).item()

            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        self.kv_indices = kv_indices


class AiterMlaIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.update = self.update_single_wrapper

        self.kv_indptr = None
        self.kv_indices = None
        self.qo_indptr = None
        self.kv_last_page_len = None
        self.max_extend_len = 0
        self.max_prefix_extend_len = 0

    def update(
        self,
        req_pool_indices: torch.Tensor,
        prefix_lens: torch.Tensor,
        prefix_lens_sum: int,
        extend_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be reassigned at runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        prefix_lens: torch.Tensor,
        prefix_lens_sum: int,
        extend_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInfo],
    ):
        paged_kernel_lens = prefix_lens
        paged_kernel_lens_sum = prefix_lens_sum

        bs = len(req_pool_indices)

        kv_indptr = self.attn_backend.kv_indptr
        if spec_info is None:
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.stride(0),
            )

            qo_indptr = self.attn_backend.qo_indptr
            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]

            max_extend_len = torch.max(extend_lens).item()
            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
            kv_indptr += qo_indptr
        else:
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        self.kv_indptr = kv_indptr
        self.kv_indices = kv_indices
        self.qo_indptr = qo_indptr
        self.max_extend_len = max_extend_len
        self.max_prefix_extend_len = max_prefix_extend_len
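

# --------------------------------------------------------------------------- #
# Illustrative sketch only (not used by the backend above): it reproduces, in
# plain PyTorch, the CSR-style paged-KV indexing that
# create_flashinfer_kv_indices_triton builds on the GPU. All tensor values
# below are made up for the example; run this file directly to see the layout.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Toy request -> token-slot table: 4 request slots, up to 8 tokens each.
    toy_req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)
    toy_req_pool_indices = [2, 0, 3]  # which table rows this batch uses
    toy_seq_lens = torch.tensor([3, 5, 2])  # current KV length per request

    # kv_indptr is an exclusive prefix sum over the per-request KV lengths.
    toy_kv_indptr = torch.zeros(len(toy_seq_lens) + 1, dtype=torch.int32)
    toy_kv_indptr[1:] = torch.cumsum(toy_seq_lens, dim=0)

    # kv_indices flattens each request's KV-cache slots back to back.
    toy_kv_indices = torch.cat(
        [
            toy_req_to_token[r, :l]
            for r, l in zip(toy_req_pool_indices, toy_seq_lens.tolist())
        ]
    )

    print(toy_kv_indptr)  # [0, 3, 8, 10]: row offsets into toy_kv_indices
    print(toy_kv_indices)  # [16, 17, 18, 0, 1, 2, 3, 4, 24, 25]: flattened slots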