from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

    # Type-only imports for the graph-capture hooks below; the module path is an
    # assumption and may differ across sglang versions.
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@dataclass
class ForwardMetadata:

    # calculated map for kv positions [bs * maxseqlen]
    block_tables: Optional[torch.Tensor] = None

    # seq len inputs
    extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_int: Optional[torch.Tensor] = None
    seq_lens_cpu_list: Optional[List[int]] = None
    seq_lens_list_cumsum: Optional[List[int]] = None


class AscendAttnBackend(AttentionBackend):

    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
        mask_flag = torch.tril(
            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
        ).view(max_seq_len, max_seq_len)
        mask_flag = ~mask_flag
        if dtype == torch.float16:
            mask_value = torch.finfo(torch.float32).min
        else:
            mask_value = 1
        self.mask = (
            torch.masked_fill(
                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
            )
            .to(dtype)
            .to(self.device)
        )
        self.mask_len = max_seq_len

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device
        self.page_size = model_runner.page_size
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        if self.use_mla:
            self.kv_lora_rank = model_runner.model_config.kv_lora_rank
            self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.native_attn = TorchNativeAttnBackend(model_runner)
        self.graph_metadata = {}
        self.max_context_len = model_runner.model_config.context_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.graph_mode = False
        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
        if not self.use_fia:
            self.gen_attention_mask(128, model_runner.dtype)

        mask_length = 2048
        self.fia_mask = ~torch.tril(
            torch.ones(
                (mask_length, mask_length),
                dtype=torch.bool,
                device=model_runner.device,
            )
        )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        tp_size = get_attention_tp_size()

        self.forward_metadata = ForwardMetadata()

        self.forward_metadata.block_tables = (
            forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
            ][:, :: self.page_size]
            // self.page_size
        )
        if forward_batch.extend_seq_lens is not None:
            self.forward_metadata.extend_seq_lens_cpu_int = (
                forward_batch.extend_seq_lens.cpu().int()
            )
        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
        if forward_batch.is_extend_in_batch:
            seq_lens_list_cumsum[-1] = (
                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
            ) * tp_size
        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

        self.graph_mode = False
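
    # Graph-capture flow (descriptive note added for clarity):
    # `init_cuda_graph_state` allocates one persistent block table of shape
    # [max_bs, max_context_len // page_size]; `init_forward_metadata_capture_cuda_graph`
    # records a per-batch-size view of it, and
    # `init_forward_metadata_replay_cuda_graph` refills that view in place before
    # each replay, so the captured graph always reads the same tensor storage.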

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.graph_metadata = {
            "block_tables": torch.empty(
                (max_bs, self.max_context_len // self.page_size),
                dtype=torch.int32,
                device=self.device,
            ),
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        metadata = ForwardMetadata()

        metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
        metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()

        self.graph_metadata[bs] = metadata
        self.forward_metadata = metadata

        self.graph_mode = True

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        metadata = self.graph_metadata[bs]
        max_len = seq_lens_cpu[:bs].max().item()
        max_seq_pages = (max_len + self.page_size - 1) // self.page_size

        metadata.block_tables[:bs, :max_seq_pages].copy_(
            self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size]
            // self.page_size
        )
        metadata.block_tables[:bs, max_seq_pages:].fill_(0)
        metadata.block_tables[bs:, :].fill_(0)

        self.forward_metadata = metadata

        self.graph_mode = True

    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_extend(
        self,
        q,
        k,
        v,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, v
                )

            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

            if self.use_fia:
                # FIA will support multi-batch in a later version of CANN.
                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                attn_output = torch.empty(
                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                    device=q.device,
                    dtype=q.dtype,
                )

                q_len_offset = 0
                for q_len in forward_batch.extend_seq_lens_cpu:
                    attn_output[q_len_offset : q_len_offset + q_len] = (
                        torch.ops.npu.npu_fused_infer_attention_score(
                            q[None, q_len_offset : q_len_offset + q_len],
                            k[None, q_len_offset : q_len_offset + q_len],
                            v[None, q_len_offset : q_len_offset + q_len],
                            num_heads=layer.tp_q_head_num,
                            num_key_value_heads=layer.tp_k_head_num,
                            # TODO: the TND layout does not support q_heads != k_heads yet.
                            input_layout="BSND",
                            atten_mask=self.fia_mask.unsqueeze(0),
                            sparse_mode=3,
                            scale=layer.scaling,
                            next_tokens=0,
                        )[0]
                    )
                    q_len_offset += q_len

                attn_output = attn_output.view(
                    -1, layer.tp_q_head_num * layer.v_head_dim
                )
            else:
                if layer.qk_head_dim <= 128:
                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
                    attn_output = torch.empty(
                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
                        dtype=query.dtype,
                        device=query.device,
                    )

                    torch_npu._npu_flash_attention_qlens(
                        query=query,
                        key_cache=k_cache,
                        value_cache=v_cache,
                        mask=self.mask,
                        block_table=self.forward_metadata.block_tables,
                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
                        context_lens=self.forward_metadata.seq_lens_cpu_int,
                        scale_value=layer.scaling,
                        num_heads=layer.tp_q_head_num,
                        num_kv_heads=layer.tp_k_head_num,
                        out=attn_output,
                    )
                else:
                    if layer.qk_head_dim != layer.v_head_dim:
                        attn_output = q.new_empty(
                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
                        )
                    else:
                        attn_output = torch.empty_like(q)

                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
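
                    # Head dims larger than 128 are not handled by the fused
                    # `_npu_flash_attention_qlens` kernel above, so this branch
                    # falls back to the torch-native SDPA prefill path; GQA is
                    # expressed through `enable_gqa` when the q and kv head
                    # counts differ.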
                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)

                    causal = True
                    if (
                        layer.is_cross_attention
                        or layer.attn_type == AttentionType.ENCODER_ONLY
                    ):
                        causal = False

                    self.native_attn._run_sdpa_forward_extend(
                        q_,
                        o_,
                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                        forward_batch.req_to_token_pool.req_to_token,
                        forward_batch.req_pool_indices,
                        forward_batch.seq_lens,
                        forward_batch.extend_prefix_lens,
                        forward_batch.extend_seq_lens,
                        scaling=layer.scaling,
                        enable_gqa=use_gqa,
                        causal=causal,
                    )
        else:
            assert (
                layer.qk_head_dim != layer.v_head_dim
            ), "FIA only supports qk_head_dim != v_head_dim"
            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)

            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                q_nope,
                k_nope,
                v,
                query_rope=q_rope,
                key_rope=k_rope,
                num_heads=layer.tp_q_head_num,
                input_layout="TND",
                atten_mask=self.fia_mask,
                sparse_mode=3,
                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
                scale=layer.scaling,
                next_tokens=0,
            )

        return attn_output

    def forward_decode_graph(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        if save_kv_cache:
            if self.use_mla:
                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, k_rope
                )
            else:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, v
                )

        if not self.use_mla:
            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                layer.layer_id
            ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                layer.layer_id
            ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)

            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
            if self.forward_metadata.seq_lens_cpu_int is None:
                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
            else:
                actual_seq_len_kv = (
                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
                )
            num_tokens = query.shape[0]
            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                query,
                k_cache,
                v_cache,
                block_table=self.forward_metadata.block_tables,
                block_size=self.page_size,
                num_heads=layer.tp_q_head_num,
                num_key_value_heads=layer.tp_k_head_num,
                input_layout="BSH",
                scale=layer.scaling,
                actual_seq_lengths_kv=actual_seq_len_kv,
            )
            output = torch.empty(
                (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                dtype=q.dtype,
                device=q.device,
            )
            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
            torch_npu.npu_fused_infer_attention_score.out(
                query,
                k_cache,
                v_cache,
                block_table=self.forward_metadata.block_tables,
                block_size=self.page_size,
                num_heads=layer.tp_q_head_num,
                num_key_value_heads=layer.tp_k_head_num,
                input_layout="BSH",
                scale=layer.scaling,
                actual_seq_lengths_kv=actual_seq_len_kv,
                workspace=workspace,
                out=[output, softmax_lse],
            )
            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
        else:
            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
            k_rope_cache = k_rope.view(
                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
            )
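            # Descriptive note: the MLA KV pool returns the compressed KV
            # (`c_kv`, width kv_lora_rank) and the rope key (`k_rope`, width
            # qk_rope_head_dim); both are reshaped to the paged BNSD layout
            # below, and `c_kv_cache` is passed as both key and value to the
            # fused kernel.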
            c_kv_cache = c_kv.view(
                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
            )

            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
            if self.forward_metadata.seq_lens_cpu_int is None:
                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
            else:
                actual_seq_len_kv = (
                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
                )

            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                q_nope,
                c_kv_cache,
                c_kv_cache,
                query_rope=q_rope,
                key_rope=k_rope_cache,
                num_heads=layer.tp_q_head_num,
                num_key_value_heads=layer.tp_k_head_num,
                block_table=self.forward_metadata.block_tables,
                block_size=self.page_size,
                input_layout="BNSD",
                scale=layer.scaling,
                actual_seq_lengths_kv=actual_seq_len_kv,
                antiquant_mode=0,
                antiquant_scale=None,
                sparse_mode=0,
            )
            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
            torch_npu.npu_fused_infer_attention_score.out(
                q_nope,
                c_kv_cache,
                c_kv_cache,
                query_rope=q_rope,
                key_rope=k_rope_cache,
                num_heads=layer.tp_q_head_num,
                num_key_value_heads=layer.tp_k_head_num,
                block_table=self.forward_metadata.block_tables,
                block_size=self.page_size,
                input_layout="BNSD",
                scale=layer.scaling,
                actual_seq_lengths_kv=actual_seq_len_kv,
                antiquant_mode=0,
                antiquant_scale=None,
                sparse_mode=0,
                workspace=workspace,
                out=[output, softmax_lse],
            )
            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
    ):
        if self.graph_mode:
            return self.forward_decode_graph(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache,
                q_rope=q_rope,
                k_rope=k_rope,
            )

        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, v
                )
            num_tokens = q.shape[0]
            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
            if self.use_fia:
                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                    q.view(
                        forward_batch.batch_size,
                        -1,
                        layer.tp_q_head_num,
                        layer.qk_head_dim,
                    ),
                    k_cache.view(
                        -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
                    ),
                    v_cache.view(
                        -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
                    ),
                    num_heads=layer.tp_q_head_num,
                    num_key_value_heads=layer.tp_k_head_num,
                    input_layout="BSND",
                    atten_mask=None,
                    block_size=self.page_size,
                    block_table=self.forward_metadata.block_tables,
                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                    scale=layer.scaling,
                )
            else:
                query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                num_tokens = query.shape[0]
                attn_output = torch.empty(
                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                    dtype=query.dtype,
                    device=query.device,
                )

                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=k_cache,
                    value_cache=v_cache,
                    num_heads=layer.tp_q_head_num,
                    num_kv_heads=layer.tp_k_head_num,
                    scale_value=layer.scaling,
                    block_table=self.forward_metadata.block_tables,
                    context_lens=self.forward_metadata.seq_lens_cpu_int,
                    out=attn_output,
                )
            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
        else:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, forward_batch.out_cache_loc, k, k_rope
                )

            num_tokens = q.shape[0]
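
            # Descriptive note: for MLA decode the key buffer stores the
            # compressed KV (`kv_c`) and the value buffer stores the rope key
            # (`k_pe`).  The fused FIA kernel is only used for GQA ratios of at
            # least 8; otherwise the nope/rope parts are concatenated and fed
            # to `_npu_paged_attention_mla`.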
            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

            if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
                # GQA ratios (tp_q_head_num // tp_k_head_num) below 8 will be
                # supported in a later version of CANN.
                kv_c = kv_c.view(
                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
                )
                k_pe = k_pe.view(
                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
                )
                q = q.view(
                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
                )
                q_rope = q_rope.view(
                    forward_batch.batch_size,
                    -1,
                    layer.tp_q_head_num,
                    self.qk_rope_head_dim,
                )
                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                    q,
                    kv_c,
                    kv_c,
                    query_rope=q_rope,
                    key_rope=k_pe,
                    num_heads=layer.tp_q_head_num,
                    num_key_value_heads=layer.tp_k_head_num,
                    input_layout="BSND",
                    atten_mask=None,
                    sparse_mode=0,
                    scale=layer.scaling,
                    antiquant_mode=0,
                    antiquant_scale=None,
                    block_table=self.forward_metadata.block_tables,
                    block_size=self.page_size,
                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
                )
            else:
                # _npu_paged_attention_mla does not support graph mode.
                assert not self.graph_mode

                q = torch.cat([q, q_rope], dim=-1)
                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
                    -1,
                    self.page_size,
                    layer.tp_k_head_num,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                )

                attn_output = torch.empty(
                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
                    dtype=q.dtype,
                    device=q.device,
                )
                torch_npu._npu_paged_attention_mla(
                    query=query,
                    key_cache=kv_c_and_k_pe_cache,
                    num_kv_heads=layer.tp_k_head_num,
                    num_heads=layer.tp_q_head_num,
                    scale_value=layer.scaling,
                    block_table=self.forward_metadata.block_tables,
                    context_lens=self.forward_metadata.seq_lens_cpu_int,
                    mla_vheadsize=self.kv_lora_rank,
                    out=attn_output,
                )
            return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
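

# ---------------------------------------------------------------------------
# Illustrative sketch (added note, not part of the backend API): the block-table
# derivation used in `init_forward_metadata` and
# `init_forward_metadata_replay_cuda_graph`, reproduced on toy data.  Running
# this module directly requires a torch_npu / sglang environment because of the
# imports at the top of the file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    page_size = 4
    # Toy req_to_token pool row: one request whose 11 tokens occupy KV slots
    # 8..18 of the pool.
    req_to_token = torch.arange(8, 8 + 16).view(1, 16)
    seq_len = 11
    # Same expression as the backend: keep every `page_size`-th slot, then
    # divide by `page_size` to turn slot indices into page ids.
    block_tables = req_to_token[:, :seq_len][:, ::page_size] // page_size
    print(block_tables)  # tensor([[2, 3, 4]])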