Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
def get_indexer_metadata(
self,
layer_id: int,
forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
"""Get the indexer metadata. None means don't support indexer."""
return None
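A backend that does support the NSA indexer is expected to override this hook and return a real metadata object. The following is a minimal sketch only; the class name NsaCapableBackend and the _build_indexer_metadata helper are assumptions for illustration, not code from this commit.

# Illustrative sketch only (not part of this commit).
from __future__ import annotations

from typing import Optional

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class NsaCapableBackend(AttentionBackend):
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Hypothetical: build the per-batch indexer metadata once per forward pass.
        self.indexer_metadata: Optional[BaseIndexerMetadata] = (
            self._build_indexer_metadata(forward_batch)
        )

    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> Optional[BaseIndexerMetadata]:
        # Returning a non-None object advertises indexer support to callers.
        return self.indexer_metadata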
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
and self.fa_impl_ver != 4
):
# 3) layer.head_dim <= 256 since the fa3 kernel requires fp16 and bf16 data types in this case.
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
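For reference, the descale tensors broadcast the per-layer KV scale to one value per (batch, kv head); a minimal standalone sketch of that expand, with made-up sizes:

# Minimal sketch of the k_descale broadcast above; the sizes are made up.
import torch

batch_size, tp_k_head_num = 4, 8
k_scale = torch.tensor([0.02])                 # per-layer scale stored on the layer
descale_shape = (batch_size, tp_k_head_num)
k_descale = k_scale.expand(descale_shape)      # broadcast view, no copy
assert k_descale.shape == descale_shape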
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask = None
else:
assert isinstance(
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> Optional[BaseIndexerMetadata]:
backend = self._select_backend(forward_batch.forward_mode)
return backend.get_indexer_metadata(layer_id, forward_batch)
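Callers typically reach this hook through the attention backend attached to the forward batch; the hybrid backend then forwards the call to whichever backend handles the current forward mode. The snippet below is a hedged usage sketch, not code from this commit.

# Hedged usage sketch (not part of this commit): how an indexer-aware layer
# might query the metadata through the forward batch.
def maybe_get_indexer_metadata(forward_batch: ForwardBatch, layer_id: int):
    metadata = forward_batch.attn_backend.get_indexer_metadata(layer_id, forward_batch)
    if metadata is None:
        # The selected backend (prefill or decode) does not support the NSA indexer.
        return None
    return metadata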
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
def forward_absorb_prepare_npu_rms_norm_cache(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch,
zero_allocator,
):
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
self.dtype = hidden_states.dtype
self.cos, self.sin = self.get_sin_cos(positions)
self.kvCache, self.kvCacheRope, self.slotmapping = (
self.get_kv_cache_and_cache_idx(forward_batch)
)
if not self.has_preprocess_weights:
self.has_preprocess_weights = True
cos, sin = self.cos, self.sin
if self.q_lora_rank is not None:
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q_lowrank)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) # b*s,n,d
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
latent_cache = latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
) # (B*S,N,1,D)
cache_mode = "PA_BNSD"
self.kvCache = self.kvCache.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.kv_lora_rank,
)
self.kvCacheRope = self.kvCacheRope.view(
-1,
forward_batch.attn_backend.page_size,
1,
forward_batch.attn_backend.qk_rope_head_dim,
)
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
latent_cache,
self.kv_a_layernorm.weight,
cos,
sin,
self.slotmapping.to(torch.int64),
self.kvCacheRope,
self.kvCache,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
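The q_nope path above applies the MLA weight-absorption step: instead of up-projecting the compressed KV cache, the non-rope part of the query is projected into the KV latent space through the per-head w_kc matrix. A shape-only sketch in plain PyTorch, with made-up sizes:

# Shape-only sketch of the q_nope absorption above (illustrative, sizes made up).
import torch

tokens, heads, d_nope, kv_lora_rank = 16, 8, 128, 512
q_nope = torch.randn(tokens, heads, d_nope)
w_kc = torch.randn(heads, d_nope, kv_lora_rank)   # per-head absorbed weight

# (heads, tokens, d_nope) @ (heads, d_nope, kv_lora_rank) -> (heads, tokens, kv_lora_rank)
q_latent = torch.matmul(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_latent.shape == (tokens, heads, kv_lora_rank)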
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
_is_w8a8 = (
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
and self.qkv_a_proj.quant_method.quantization_config.get_name()
== "w8a8_int8"
)
if _is_w8a8:
return self.forward_mlapo(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb_prepare_npu_rms_norm_cache(
positions, hidden_states, forward_batch, zero_allocator
)
from .topk import fast_topk, fast_topk_transform
__all__ = ["fast_topk", "fast_topk_transform"]
from __future__ import annotations
from typing import Any
import torch
from .utils import load_kernel_module
def _load_topk_module() -> Any:
"""
Load the index manipulation module.
"""
return load_kernel_module("topk.cu", "topk_kernel")
# TODO(dark): figure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
_USE_TL = True
def fast_topk(
score: torch.Tensor,
indices: torch.Tensor,
lengths: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk(score, indices, lengths, _USE_TL)
def fast_topk_transform(
score: torch.Tensor,
lengths: torch.Tensor,
dst_page_table: torch.Tensor,
src_page_table: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
return _load_topk_module().fast_topk_transform(
score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
)
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any, Iterable
@lru_cache()
def _prepare_for_load() -> str:
import os
import warnings
warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
)
return os.path.dirname(os.path.abspath(__file__))
@lru_cache()
def load_kernel_module(
path: str | Iterable[str],
name: str,
*,
build: str = "build",
cflags: Iterable[str] | None = None,
cuda_flags: Iterable[str] | None = None,
ldflags: Iterable[str] | None = None,
) -> Any:
from torch.utils.cpp_extension import load
if isinstance(path, str):
path = (path,)
abs_path = _prepare_for_load()
build_dir = f"{abs_path}/{build}"
os.makedirs(build_dir, exist_ok=True)
return load(
name=name,
sources=[f"{abs_path}/csrc/{p}" for p in path],
extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
extra_ldflags=list(ldflags or []) or None,
build_directory=build_dir,
)
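For reference, the helper resolves sources against the package's csrc/ directory and JIT-compiles them with torch.utils.cpp_extension.load, caching the resulting module via lru_cache. A hedged usage sketch follows; the source file name and extra flags are hypothetical examples, not part of this commit.

# Hypothetical usage sketch: "my_kernel.cu" and the extra flags are examples only.
mod = load_kernel_module(
    "my_kernel.cu",                         # resolved to <package>/csrc/my_kernel.cu
    "my_kernel_ext",                        # module name passed to cpp_extension.load
    cuda_flags=("-O3", "--use_fast_math"),  # tuple keeps the lru_cache key hashable
)
# The returned module exposes whatever functions are bound in the compiled extension.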