Commit 7993ed8d authored by maxiao1

Adapt to DeepSeek V3.2

parent 443a1b4a
# PyTorch reference implementations used in place of the flash_mla sparse kernels.
import math
from typing import Optional, Tuple, List

import torch


def cdiv(x: int, y: int):
    return (x + y - 1) // y


def native_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """A PyTorch reference for sparse MLA attention over top-k selected KV slots."""
    s_q, _, d_qk = q.size()
    s_kv = kv.size(0)
    topk = indices.size(-1)

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        # logsumexp in base 2: log2(sum(2 ** a))
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q.float()  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]
    # Gather the selected KV rows per query token; invalid slots point to row 0 and are masked below.
    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score @ kvs[:, :, :d_v]
    return (max_logits, lse, result)
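For orientation, here is a minimal shape check of the sparse path, assuming the definitions above are in scope; the sizes below are illustrative assumptions, not values used by the backend.

# Hypothetical toy sizes: 4 query tokens, 2 heads, d_qk = 576 (512 value dims + 64 rope dims),
# 16 cached tokens, top-8 selected slots per query; -1 marks padded index slots.
s_q, h_q, d_qk, s_kv, topk = 4, 2, 576, 16, 8
q = torch.randn(s_q, h_q, d_qk)
kv = torch.randn(s_kv, 1, d_qk)                    # single latent KV "head"
indices = torch.randint(-1, s_kv, (s_q, 1, topk))  # [s_q, 1, topk]
max_logits, lse, out = native_mla_sparse_fwd(q, kv, indices, sm_scale=d_qk ** -0.5)
assert out.shape == (s_q, h_q, 512)                # d_v defaults to 512
assert max_logits.shape == lse.shape == (s_q, h_q)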
def native_mla_with_kvcache(
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    block_table: torch.Tensor,  # [batch_size, ?]
    cache_seqlens: torch.Tensor,  # [batch_size]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None,  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation of MLA attention over a paged KV cache in PyTorch.
    """

    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        mask = torch.zeros(s_q, s_k, dtype=torch.bool, device=indices.device)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        batch_idx: int,
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        kv[kv != kv] = 0.0  # zero out NaNs from uninitialized cache slots
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and query.size(1) > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device=query.device)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            attn_weight += attn_bias
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
        # Correct q tokens which have no attendable k
        lonely_q_mask = lse == float("-inf")
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")
        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0:cur_num_blocks]
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            i,
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None,
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref
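Similarly, a small paged-cache sketch for native_mla_with_kvcache; block size, lengths, and dtypes below are assumptions for illustration only.

# Hypothetical single-sequence decode step: an 8-token cache split into two
# 4-slot blocks, one query token, no top-k indices.
b, s_q, h_q, d, dv, block_size = 1, 1, 2, 576, 512, 4
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16)
blocked_k = torch.randn(2, block_size, 1, d, dtype=torch.bfloat16)
block_table = torch.tensor([[0, 1]])
cache_seqlens = torch.tensor([8], dtype=torch.int32)
out, lse = native_mla_with_kvcache(q, blocked_k, block_table, cache_seqlens, dv, is_causal=True)
assert out.shape == (b, s_q, h_q, dv) and out.dtype == torch.bfloat16
assert lse.shape == (b, h_q, s_q)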
# fallback_fp8.py
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
import torch

from sglang.srt.utils import ceil_div


@torch.no_grad()
def fallback_fp8_mqa_logits(q: torch.Tensor,
                            kv: torch.Tensor,
                            weights: torch.Tensor,
                            ks: torch.Tensor,
                            ke: torch.Tensor,
                            cost_only: bool = False) -> torch.Tensor:
    """
    PyTorch fallback for fp8_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (M, H, D) queries
        kv: (N, D) keys
        weights: (M, H) per-head weights
        ks: (M,) int32 start of the valid key range per query
        ke: (M,) int32 end (exclusive) of the valid key range per query
        cost_only: if True, only return the number of valid (query, key) pairs

    Returns:
        logits: (M, N) with -inf outside of the valid [ks, ke) range
    """
    seq_len_kv = kv.shape[0]
    if cost_only:
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()

    k = kv
    q = q.float()
    k = k.float()
    # Valid key positions for query i are ks[i] <= n < ke[i].
    mask_lo = torch.arange(0, seq_len_kv, device=k.device)[None, :] >= ks[:, None]
    mask_hi = torch.arange(0, seq_len_kv, device=k.device)[None, :] < ke[:, None]
    mask = mask_lo & mask_hi
    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))
    # cost = mask.sum()
    return logits

    # Equivalent per-row reference (kept for documentation):
    # M, H, D = q.shape
    # N = k.shape[0]
    # logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
    # for i in range(M):
    #     start = max(ks[i].item(), 0)
    #     end = min(ke[i].item(), N)
    #     if start >= end:
    #         continue
    #     qi = q[i]          # (H, D)
    #     ki = k[start:end]  # (L, D)
    #     sim = torch.matmul(qi, ki.T)  # (H, L)
    #     weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0)  # (L,)
    #     logits[i, start:end] = weighted_sim
    # return logits
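A quick sketch of how the [ks, ke) ranges mask the logits, using toy tensors with assumed values.

# Hypothetical check: 2 queries, 1 head, 4 keys.
# Query 0 may attend to keys [0, 2); query 1 to keys [1, 4).
q = torch.randn(2, 1, 8)
kv = torch.randn(4, 8)
weights = torch.ones(2, 1)
ks = torch.tensor([0, 1], dtype=torch.int32)
ke = torch.tensor([2, 4], dtype=torch.int32)
logits = fallback_fp8_mqa_logits(q, kv, weights, ks, ke)
assert logits.shape == (2, 4)
assert torch.isinf(logits[0, 2]) and torch.isinf(logits[1, 0])   # outside the valid ranges
assert torch.isfinite(logits[0, 0]) and torch.isfinite(logits[1, 3])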
@torch.no_grad()
def fallback_fp8_paged_mqa_logits(q: torch.Tensor,
                                  kv_cache: torch.Tensor,
                                  weights: torch.Tensor,
                                  context_lens: torch.Tensor,
                                  block_tables: torch.Tensor,
                                  max_model_len: int) -> torch.Tensor:
    """
    PyTorch fallback for fp8_paged_mqa_logits. No real fp8 is used, just FP32.

    Args:
        q: (B, next_n, H, D) queries
        kv_cache: (num_blocks, block_size, 1, D) paged keys
        weights: (B * next_n, H) per-head weights
        context_lens: (B,) context length per sequence
        block_tables: (B, max_blocks) block indices per sequence
        max_model_len: int

    Returns:
        logits: (B * next_n, max_model_len) with -inf at non-attendable positions
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        # The next_n query tokens sit at the end of the context.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            # Compute in fp32 (no real fp8 here).
            qx, kx = q[i].float(), kv_cache[block_idx].float()
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
            )
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype),
                float('-inf'),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n:(i + 1) * next_n,
                block_rk * block_size:(block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
    return logits

    # Equivalent per-block reference variant (unreachable, kept for documentation):
    # B, N, H, D = q.shape
    # block_size = kv_cache.shape[1]
    # logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
    # for i in range(B):
    #     ctx_len = context_lens[i].item()
    #     q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
    #     weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
    #     for br in range((ctx_len + block_size - 1) // block_size):
    #         blk_idx = block_tables[i, br].item()
    #         if blk_idx < 0:
    #             continue
    #         qx = q[i]               # (N, H, D)
    #         kx = kv_cache[blk_idx]  # (block_size, 1, D)
    #         kx = kx.squeeze(1)      # (block_size, D)
    #         k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
    #         mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None])  # (N, block_size)
    #         s = torch.where(mask[None, :, :],
    #                         torch.einsum('nhd,ld->hnl', qx, kx),
    #                         torch.full((H, N, block_size), float("-inf"), device=q.device))
    #         s = s.relu() * weight_slice[..., None]
    #         logits_slice = s.sum(dim=0)  # (N, block_size)
    #         mask_block = (k_offsets[None, :] <= q_offsets[:, None])
    #         logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
    #             torch.where(mask_block, logits_slice, float("-inf"))
    # return logits
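And a paged-layout sketch for fallback_fp8_paged_mqa_logits; block size, table contents, and lengths are assumed toy values.

# Hypothetical check: batch of 1, next_n = 1 query, block_size = 4, a context of 6
# tokens stored in blocks 2 and 0 of a 3-block cache; positions >= 6 stay at -inf.
B, next_n, H, D, block_size, max_model_len = 1, 1, 2, 8, 4, 8
q = torch.randn(B, next_n, H, D)
kv_cache = torch.randn(3, block_size, 1, D)
weights = torch.ones(B * next_n, H)
context_lens = torch.tensor([6], dtype=torch.int32)
block_tables = torch.tensor([[2, 0, -1]])   # -1 is unused padding
logits = fallback_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)
assert logits.shape == (B * next_n, max_model_len)
assert torch.isinf(logits[0, 6]) and torch.isinf(logits[0, 7])
assert torch.isfinite(logits[0, :6]).all()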
@@ -3,6 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -14,7 +15,7 @@ from sglang.srt.utils import add_prefix, is_npu
 if not is_npu():
     from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
-    import deep_gemm
+    # import deep_gemm
 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
 from sglang.srt.layers.dp_attention import get_attention_tp_group
@@ -27,14 +28,14 @@ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import add_prefix, align, is_cuda
-try:
-    import deep_gemm_v32
-except ImportError as e:
-    print("Error when importing deep_gemm_v32, try deep_gemm")
-    try:
-        import deep_gemm as deep_gemm_v32
-    except ImportError as e:
-        print("Error when importing deep_gemm, skip")
+# try:
+#     import deep_gemm_v32
+# except ImportError as e:
+#     print("Error when importing deep_gemm_v32, try deep_gemm")
+#     try:
+#         import deep_gemm as deep_gemm_v32
+#     except ImportError as e:
+#         print("Error when importing deep_gemm, skip")
 if TYPE_CHECKING:
@@ -81,16 +82,47 @@ class BaseIndexerMetadata(ABC):
         Don't assume it is the topk indices of the input logits.
         """
+def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
+    """
+    A native PyTorch implementation of the Fast Hadamard Transform that mimics
+    the behavior of the custom CUDA kernel's call signature.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
+        scale (float): The normalization factor to multiply the result by.
+
+    Returns:
+        torch.Tensor: The Hadamard transformed tensor.
+    """
+    # Base case for recursion
+    if x.shape[-1] == 1:
+        return x
+    # Split the tensor into two halves
+    half_size = x.shape[-1] // 2
+    a = x[..., :half_size]
+    b = x[..., half_size:]
+    # Recursive calls (no scaling in intermediate steps)
+    a_transformed = hadamard_transform_pytorch(a, scale=1.0)
+    b_transformed = hadamard_transform_pytorch(b, scale=1.0)
+    # Combine the results
+    combined = torch.cat([a_transformed + b_transformed, a_transformed - b_transformed], dim=-1)
+    # Apply the scale only at the final step
+    return combined * scale
 def rotate_activation(x: torch.Tensor) -> torch.Tensor:
     assert x.dtype == torch.bfloat16
-    from fast_hadamard_transform import hadamard_transform
+    # from fast_hadamard_transform import hadamard_transform
     hidden_size = x.size(-1)
     assert (
         hidden_size & (hidden_size - 1)
     ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
-    return hadamard_transform(x, scale=hidden_size**-0.5)
+    return hadamard_transform_pytorch(x, scale=hidden_size**-0.5)
 class V32LayerNorm(nn.Module):
@@ -140,7 +172,7 @@ class Indexer(CustomOp):
         self.layer_id = layer_id
         self.alt_stream = alt_stream
         if not is_npu():
-            self.sm_count = deep_gemm.get_num_sms()
+            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
            self.half_device_sm_count = align(self.sm_count // 2, 8)
         self.wq_b = ReplicatedLinear(
@@ -273,9 +305,7 @@ class Indexer(CustomOp):
         k_rope, _ = torch.split(
             key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
         )
         q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
         query[..., : self.rope_head_dim] = q_rope
         key[..., : self.rope_head_dim] = k_rope
@@ -323,9 +353,9 @@ class Indexer(CustomOp):
         blocksize = page_size
         seqlens_32 = metadata.get_seqlens_int32()
         # NOTE(dark): 132 is SM count on H200/B200, not magic number
-        schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
-            seqlens_32, blocksize, self.sm_count
-        )
+        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
+        #     seqlens_32, blocksize, self.sm_count
+        # )
         assert len(q_fp8.shape) == 3
         q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
@@ -339,15 +369,13 @@ class Indexer(CustomOp):
         assert len(weights.shape) == 3
         weights = weights.squeeze(2)
-        logits = deep_gemm_v32.fp8_paged_mqa_logits(
+        logits = fallback_fp8_paged_mqa_logits(
             q_fp8,
             kv_cache_fp8,
             weights,
             seqlens_32,
             block_tables,
-            schedule_metadata,
             max_seq_len,
-            clean_logits=False,
         )
         # NOTE(dark): logits should be cleaned in topk_transform
@@ -408,13 +436,12 @@ class Indexer(CustomOp):
         seq_lens_expanded = metadata.get_seqlens_expanded()
         ke = ks + seq_lens_expanded
-        logits = deep_gemm_v32.fp8_mqa_logits(
+        logits = fallback_fp8_mqa_logits(
             q_fp8,
-            kv_fp8,
+            k_fp8,
             weights,
             ks,
-            ke,
+            ke
         )
         assert logits.shape[0] == len(seq_lens_expanded)
......
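Referring back to the hadamard_transform_pytorch fallback added in the hunk above, a quick sanity sketch (sizes are arbitrary assumptions): the recursive transform should match a dense multiply by an explicitly built Sylvester-Hadamard matrix at the 1/sqrt(N) scale that rotate_activation uses.

import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]]
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

n = 8
x = torch.randn(4, n)
ref = (x @ hadamard_matrix(n).T) * n ** -0.5
out = hadamard_transform_pytorch(x, scale=n ** -0.5)
assert torch.allclose(out, ref, atol=1e-5)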
 from typing import Optional, Tuple
-import tilelang
-import tilelang.language as T
+# import tilelang
+# import tilelang.language as T
 import torch
-tilelang.set_log_level("WARNING")
+# tilelang.set_log_level("WARNING")
-pass_configs = {
-    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
-}
+# pass_configs = {
+#     tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+#     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+#     tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+# }
 BF16 = "bfloat16"
 FP8 = "float8_e4m3"
 FP32 = "float32"
+'''
 def fast_log2_ceil(x):
     bits_x = T.reinterpret("uint32", x)
     exp_x = (bits_x >> 23) & 0xFF
@@ -32,7 +32,6 @@ def fast_pow2(x):
 def fast_round_scale(amax, fp8_max_inv):
     return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
 @tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(
     N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
@@ -83,7 +82,6 @@ def act_quant_kernel(
     return act_quant_kernel_
 def act_quant(
     x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -753,7 +751,6 @@ def sparse_attention_fwd_kernel_v2(
     return main
 def tilelang_sparse_fwd(
     q: torch.Tensor,
     kv: torch.Tensor,
@@ -772,3 +769,45 @@ def tilelang_sparse_fwd(
         num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
     )
     return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))  # type: ignore
+'''
+def act_quant(
+    x: torch.Tensor,
+    block_size: int = 128,
+    scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    PyTorch fallback for act_quant: block-wise FP8 E4M3 quantization.
+    """
+    if not x.is_contiguous():
+        x = x.contiguous()
+    N = x.size(-1)
+    assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
+    # Reshape to blocks
+    x_2d = x.view(-1, N)
+    x_blocks = x_2d.view(-1, block_size)
+    # Compute absmax per block
+    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
+    # FP8 E4M3 max value is ~448
+    fp8_max = 448.0
+    scale = amax / fp8_max
+    if scale_fmt is not None:
+        # Simulate a rounded scale (snapped to a 1/256 grid)
+        scale = torch.round(scale * 256) / 256
+    # Quantize and clamp
+    y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
+    # Convert to FP8
+    q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
+    # Reshape scale
+    s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
+    s = s.view(*x.shape[:-1], N // block_size)
+    return q.view_as(x), s
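A round-trip sketch for the act_quant fallback above (shapes and the error bound are illustrative assumptions): dequantizing with the per-block scales should approximately recover the input within FP8 precision.

import torch

x = torch.randn(4, 256, dtype=torch.bfloat16)
q8, s = act_quant(x, block_size=128)
assert q8.shape == x.shape and q8.dtype == torch.float8_e4m3fn
assert s.shape == (4, 2)                    # one scale per 128-wide block
# Dequantize block-wise: value * its block scale.
deq = q8.float().view(4, 2, 128) * s.unsqueeze(-1)
assert (deq.view(4, 256) - x.float()).abs().max() < 0.5   # loose bound for a toy check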
@@ -105,7 +105,7 @@ def transform_index_page_table_decode_ref(
     torch.gather(
         page_table,
         dim=1,
-        index=topk_indices.clamp(min=0),
+        index=topk_indices.clamp(min=0).long(),
         out=result,
     )
     result[topk_indices < 0] = -1
......
@@ -10,7 +10,6 @@ from typing import (
     Tuple,
     TypeAlias,
     Union,
-    override,
 )
 import torch
@@ -101,19 +100,15 @@ class NSAMetadata:
 class NSAIndexerMetadata(BaseIndexerMetadata):
     attn_metadata: NSAMetadata
-    @override
     def get_seqlens_int32(self) -> torch.Tensor:
         return self.attn_metadata.cache_seqlens_int32
-    @override
     def get_page_table_64(self) -> torch.Tensor:
         return self.attn_metadata.real_page_table
-    @override
     def get_seqlens_expanded(self) -> torch.Tensor:
         return self.attn_metadata.nsa_seqlens_expanded
-    @override
     def topk_transform(
         self,
         logits: torch.Tensor,
@@ -524,21 +519,25 @@ class NativeSparseAttnBackend(AttentionBackend):
             extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
             page_size=1,
         )
-        if NSA_PREFILL_IMPL == "tilelang":
-            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
-                tilelang_sparse_fwd,
-            )
-            if q_rope is not None:
-                q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_tilelang(
-                q_all=q_all,
-                kv_cache=kv_cache,
-                page_table_1=page_table_1,
-                sm_scale=layer.scaling,
-                v_head_dim=layer.v_head_dim,
-            )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        # if NSA_PREFILL_IMPL == "tilelang":
+        #     from sglang.srt.layers.attention.nsa.tilelang_kernel import (
+        #         tilelang_sparse_fwd,
+        #     )
+        #     if q_rope is not None:
+        #         q_all = torch.cat([q_nope, q_rope], dim=-1)
+        #     return self._forward_tilelang(
+        #         q_all=q_all,
+        #         kv_cache=kv_cache,
+        #         page_table_1=page_table_1,
+        #         sm_scale=layer.scaling,
+        #         v_head_dim=layer.v_head_dim,
+        #     )
+        # elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        # Skip tilelang dependencies
+        if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
             return self._forward_flashmla_prefill(
@@ -733,9 +732,9 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
+        # from flash_mla import flash_mla_sparse_fwd
+        from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
-        o, _, _ = flash_mla_sparse_fwd(
+        _, _, o = native_mla_sparse_fwd(
             q=q_all,
             kv=kv_cache,
             indices=page_table_1.unsqueeze(1),
@@ -756,8 +755,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         topk_indices,
         block_table,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
+        # from flash_mla import flash_mla_with_kvcache
+        from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache
         cache_seqlens = metadata.nsa_cache_seqlens_int32
         # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
@@ -769,7 +768,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             # inefficiently quantize the whole cache
             kv_cache = quantize_k_cache(kv_cache)
-        o, _ = flash_mla_with_kvcache(
+        o, _ = native_mla_with_kvcache(
             q=q_all,
             k_cache=kv_cache,
             cache_seqlens=cache_seqlens,
......
@@ -136,21 +136,21 @@ class RMSNorm(CustomOp):
         # NOTE: Remove this if aiter kernel supports discontinuous input
         x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
+            # if _vllm_version < Version("0.9"):
                 fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
                 return x, residual
-            else:
-                residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
-                fused_add_rms_norm(
-                    output,
-                    x,
-                    residual_out,
-                    residual,
-                    self.weight.data,
-                    self.variance_epsilon,
-                )
-                return output, residual_out
+            # else:
+            #     residual_out = torch.empty_like(x)
+            #     output = torch.empty_like(x)
+            #     fused_add_rms_norm(
+            #         output,
+            #         x,
+            #         residual_out,
+            #         residual,
+            #         self.weight.data,
+            #         self.variance_epsilon,
+            #     )
+            #     return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
......
@@ -765,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
         query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+        cos_for_key = cos[:, 0, ...]
+        sin_for_key = sin[:, 0, ...]
+        key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
+        # key_rot = key_rot * cos + rotate_fn(key_rot) * sin
         if self.rotary_dim < self.head_size:
             query = torch.cat((query_rot, query_pass), dim=-1)
......