Unverified Commit 1e0e5497 authored by ronnie_zheng, committed by GitHub

Ascend attention backend (PA & MLA) (#7722)


Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
parent b5822651
@@ -9,6 +9,7 @@
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |

Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`.
This is because a page size of 16 can be converted to a page size of 1 in the kernel backend.
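For intuition, a minimal sketch of the slot-to-page conversion the note describes (not part of this commit; `slot_to_page` is a hypothetical helper, not SGLang API):

```python
# Hypothetical helper: split a global KV slot index into (page id, in-page offset).
def slot_to_page(slot: int, page_size: int) -> tuple[int, int]:
    return slot // page_size, slot % page_size

# With page_size=16 a slot maps to a page plus an offset; with page_size=1 the page id
# equals the slot and the offset is always 0, so indices built for page size 16 can be
# consumed by a kernel that only understands page size 1.
assert slot_to_page(37, 16) == (2, 5)
assert slot_to_page(37, 1) == (37, 0)
```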
@@ -46,3 +47,8 @@ python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
- Ascend
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
@dataclass
class ForwardMetadata:
# calculated map for kv positions [bs * maxseqlen]
block_tables: Optional[torch.Tensor] = None
# seq len inputs
extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_int: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
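# Precompute a causal attention mask once at init time: mask_flag marks the strictly
# upper-triangular (future) positions, which are filled with a large negative value for
# float16 (1 for other dtypes) while allowed positions stay 0.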
def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
mask_flag = torch.tril(
torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
).view(max_seq_len, max_seq_len)
mask_flag = ~mask_flag
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
self.mask = (
torch.masked_fill(
torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
)
.to(dtype)
.to(self.device)
)
self.mask_len = max_seq_len
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = ForwardMetadata()
self.device = model_runner.device
self.gen_attention_mask(128, model_runner.dtype)
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.native_attn = TorchNativeAttnBackend(model_runner)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
][:, :: self.page_size]
// self.page_size
)
if forward_batch.extend_seq_lens is not None:
self.forward_metadata.extend_seq_lens_cpu_int = (
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
def forward_extend(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if not self.use_mla:
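# Standard MHA extend: flatten the query to (num_tokens, heads * head_dim) and let the
# fused NPU kernel gather K/V from the paged cache through the precomputed block tables,
# applying the causal mask built at init time.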
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
output = torch.empty(
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=k_cache,
value_cache=v_cache,
mask=self.mask,
block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
context_lens=self.forward_metadata.seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=output,
)
return output
else:
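# MLA extend falls back to the torch-native SDPA path: the cached K holds the compressed
# latent plus the rope head (kv_lora_rank + qk_rope_head_dim) and the cached V holds only
# the latent, so both are reshaped per head before the SDPA call.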
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
causal = True
if (
layer.is_cross_attention
or layer.attn_type == AttentionType.ENCODER_ONLY
):
causal = False
self.native_attn._run_sdpa_forward_extend(
q_,
o_,
k_cache.view(
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
),
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=causal,
)
return o
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla:
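# Standard MHA decode: one new token per request, answered by the NPU paged-attention
# kernel using the block tables and CPU sequence lengths from init_forward_metadata.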
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
num_tokens = query.shape[0]
output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=output,
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
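# MLA decode: the query keeps its full head_dim while the paged cache stores the
# compressed KV; the MLA kernel returns kv_lora_rank values per head (mla_vheadsize).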
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
num_tokens = query.shape[0]
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)
attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mla_vheadsize=self.kv_lora_rank,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
import einops
import torch
from sgl_kernel import silu_and_mul
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
@@ -50,13 +49,18 @@ from sglang.srt.utils import (
dispose_tensor,
get_bool_env_var,
is_hip,
is_npu,
set_weight_attrs,
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not _is_npu:
from sgl_kernel import silu_and_mul
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
...
@@ -321,6 +321,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
routed_scaling_factor,
)
def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
...
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
is_cpu,
is_cuda,
is_hip,
is_npu,
)
_is_cuda = is_cuda()
@@ -42,6 +43,7 @@ _is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
if _is_cuda:
from sgl_kernel import moe_fused_gate
@@ -159,6 +161,9 @@ def grouped_topk_gpu(
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
# NPU compiler limitation
if _is_npu and scores.dtype == torch.bfloat16:
scores = scores.to(torch.float16)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
...
@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
device: Optional[str] = "cuda" if not _is_npu else "npu",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
)
# Re-dispatch
if _is_hip:
if _is_hip or _is_npu:
self._forward_method = self.forward_native
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
...
@@ -1673,6 +1673,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "cutlass_mla"
or global_server_args_dict["attention_backend"] == "ascend"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
@@ -1875,7 +1876,10 @@ def get_last_loc(
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if global_server_args_dict["attention_backend"] != "torch_native":
if (
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
...
@@ -540,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
self.is_not_in_free_group = True
self.free_group = []
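# Host-side (plain torch) allocation helpers used by AscendPagedTokenToKVPoolAllocator
# below: they hand out page-aligned KV slot indices for extend and decode without Triton.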
def alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
free_pages,
out_indices,
page_size,
device,
):
extend_lens = seq_lens - prefix_lens
end_pos = torch.cumsum(extend_lens, 0)
start_pos = end_pos - extend_lens
num_new_pages = (seq_lens + page_size - 1) // page_size - (
prefix_lens + page_size - 1
) // page_size
num_full_new_pages = (seq_lens) // page_size - (
prefix_lens + page_size - 1
) // page_size
need_page = num_new_pages - num_full_new_pages
end_new_pages = torch.cumsum(num_new_pages, 0)
start_new_pages = end_new_pages - num_new_pages
pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
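# Each request's new tokens fall into up to three segments:
#   num1 - tokens that still fit in the partially filled last page of the prefix,
#   num2 - tokens that occupy whole newly allocated pages,
#   num3 - the remaining tail placed in one final partial page.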
for i in range(len(prefix_lens)):
num1 = (
min(
seq_lens[i],
(prefix_lens[i] + page_size - 1) // page_size * page_size,
)
- prefix_lens[i]
)
if num1:
out_indices[start_pos[i] : start_pos[i] + num1] = (
last_loc[i] + 1 + pos_in_page[:num1].view(-1)
)
num2 = (
seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
) * page_size
if num2:
pages = (
free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
* page_size
)
out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
pages.view(-1, 1) + pos_in_page.view(1, -1)
).view(-1)
num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
if num3:
out_indices[end_pos[i] - num3 : end_pos[i]] = (
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
).view(-1)
return num_new_pages
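# Decode allocates at most one new page per request: a page is taken from free_pages only
# when the new token starts a fresh page; otherwise it is written right after last_loc.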
def alloc_decode_kernel_ascend(
seq_lens,
last_loc,
free_pages,
out_indices,
page_size,
):
num_new_pages = (seq_lens + page_size - 1) // page_size - (
seq_lens - 1 + page_size - 1
) // page_size
end_new_pages = torch.cumsum(num_new_pages, 0)
start_new_pages = end_new_pages - num_new_pages
for i in range(len(seq_lens)):
if num_new_pages[i]:
out_indices[i] = free_pages[start_new_pages[i]] * page_size
else:
out_indices[i] = last_loc[i] + 1
return num_new_pages
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
super().__init__(size, page_size, dtype, device, kvcache)
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
)
self.ret_values = alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.sum()
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
self.ret_values = alloc_decode_kernel_ascend(
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.sum()
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def clear(self):
super().clear()
self.free_pages = self.free_pages.to(torch.int32)
@@ -568,6 +568,76 @@ class SWAKVPool(KVCache):
)
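# Ascend MHA KV cache: K/V buffers are laid out page-first as
# (num_pages + 1, page_size, head_num, head_dim); new entries are scattered in by
# torch_npu's reshape-and-cache op in set_kv_buffer below.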
class AscendTokenToKVPool(MHATokenToKVPool):
def _create_buffers(self):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.head_num,
self.head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.head_num,
self.head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
cache_k.div_(k_scale)
if v_scale is not None:
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
import torch_npu
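# Scatter the incoming K/V for this layer into the paged buffers at the given cache
# locations using the NPU fused reshape-and-cache op.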
torch_npu._npu_reshape_and_cache(
key=cache_k,
value=cache_v,
key_cache=self.k_buffer[layer_id].view(
-1, self.page_size, self.head_num, self.head_dim
),
value_cache=self.v_buffer[layer_id].view(
-1, self.page_size, self.head_num, self.head_dim
),
slot_indices=loc,
)
@triton.jit
def set_mla_kv_buffer_kernel(
kv_buffer_ptr,
@@ -820,6 +890,84 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()
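# Ascend MLA KV cache: one paged buffer per layer stores the compressed latent
# (kv_lora_rank) concatenated with the rope head (qk_rope_head_dim) for every token.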
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
layer_num: int,
device: str,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
super(MLATokenToKVPool, self).__init__(
size,
page_size,
dtype,
layer_num,
device,
enable_memory_saver,
start_layer,
end_layer,
)
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.custom_mem_pool = None
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(
self.size // self.page_size + 1,
self.page_size,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(layer_num)
]
self.layer_transfer_counter = None
kv_size = self.get_kv_size_bytes()
logger.info(
f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
)
self.mem_usage = kv_size / GB
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
import torch_npu
torch_npu._npu_reshape_and_cache_siso(
key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
),
slot_indices=loc,
)
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
...
@@ -39,7 +39,12 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
from sglang.srt.utils import (
flatten_nested_list,
get_compiler_backend,
is_npu,
support_triton,
)
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -50,6 +55,8 @@ if TYPE_CHECKING:
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
_is_npu = is_npu()
class ForwardMode(IntEnum):
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
@@ -739,7 +746,7 @@ def compute_position_torch(
return positions.to(torch.int64), extend_start_loc
@torch.compile(dynamic=True, backend=get_compiler_backend())
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
...
@@ -72,12 +72,15 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
AscendPagedTokenToKVPoolAllocator,
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import (
AscendMLAPagedTokenToKVPool,
AscendTokenToKVPool,
DoubleSparseTokenToKVPool,
MHATokenToKVPool,
MLATokenToKVPool,
@@ -110,6 +113,7 @@ from sglang.srt.utils import (
is_hip,
is_hopper_with_cuda_12_3,
is_no_spec_infer_or_topk_one,
is_npu,
monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
set_cpu_offload_max_bytes,
@@ -117,6 +121,7 @@ from sglang.srt.utils import (
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
# Use a small KV cache pool size for tests in CI
@@ -308,6 +313,7 @@ class ModelRunner:
self.init_cuda_graphs()
else:
self.cuda_graph_runner = None
self.cuda_graph_mem_usage = 0
self.init_attention_backend()
# auxiliary hidden capture mode. TODO: expose this to server args?
@@ -369,6 +375,8 @@ class ModelRunner:
server_args.attention_backend = "fa3"
elif _is_hip:
server_args.attention_backend = "aiter"
elif _is_npu:
server_args.attention_backend = "ascend"
else:
server_args.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
@@ -388,6 +396,8 @@ class ModelRunner:
server_args.attention_backend = "aiter"
else:
server_args.attention_backend = "triton"
elif _is_npu:
server_args.attention_backend = "ascend"
else:
server_args.attention_backend = "triton"
logger.info(
@@ -402,6 +412,7 @@ class ModelRunner:
"triton",
"flashmla",
"cutlass_mla",
"ascend",
]:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -1096,7 +1107,35 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.use_mla_backend:
if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
self.token_to_kv_pool = AscendTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=(
self.model_config.num_hidden_layers
if not self.is_draft_worker
else self.model_config.hf_config.num_nextn_predict_layers
), # PP is not compatible with mla backend
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
elif self.use_mla_backend:
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1175,6 +1214,15 @@ class ModelRunner:
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
if _is_npu:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
@@ -1229,6 +1277,10 @@ class ModelRunner:
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
return AiterAttnBackend(self)
elif self.server_args.attention_backend == "ascend":
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
return AscendAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
...
@@ -956,7 +956,9 @@ class DeepseekV2AttentionMLA(nn.Module):
else:
return AttnForwardMethod.MLA
if self.attention_backend == "flashinfer":
if self.attention_backend == "ascend":
return AttnForwardMethod.MLA
elif self.attention_backend == "flashinfer":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
if (
not self.flashinfer_mla_disable_ragged
...
@@ -380,6 +380,12 @@ class ServerArgs:
)
self.disable_cuda_graph = True
if self.attention_backend == "ascend":
logger.warning(
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
)
self.page_size = 128
# Choose grammar backend
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
@@ -1113,6 +1119,7 @@ class ServerArgs:
"flashmla",
"intel_amx",
"torch_native",
"ascend",
"triton", "triton",
], ],
default=ServerArgs.attention_backend, default=ServerArgs.attention_backend,
......
...@@ -2399,7 +2399,7 @@ def bind_or_assign(target, source): ...@@ -2399,7 +2399,7 @@ def bind_or_assign(target, source):
def support_triton(backend: str) -> bool: def support_triton(backend: str) -> bool:
return backend not in ["torch_native", "intel_amx"] return backend not in ["torch_native", "intel_amx", "ascend"]
try: try:
......
...@@ -143,6 +143,9 @@ suites = { ...@@ -143,6 +143,9 @@ suites = {
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
TestFile("test_reasoning_parser.py", 5), TestFile("test_reasoning_parser.py", 5),
], ],
"per-commit-npu": [
TestFile("test_ascend_attention_backend.py", 400),
],
"per-commit-2-gpu": [ "per-commit-2-gpu": [
TestFile("models/lora/test_lora_tp.py", 116), TestFile("models/lora/test_lora_tp.py", 116),
TestFile("test_data_parallelism.py", 73), TestFile("test_data_parallelism.py", 73),
......
"""
Usage:
python3 -m unittest test_ascend_attention_backend.TestAscendAttnBackend.test_gsm8k
"""
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
class TestAscendAttnBackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_offline_throughput(
DEFAULT_MODEL_NAME_FOR_TEST,
[
"--attention-backend",
"ascend",
],
)
print(f"{output_throughput=}")
if is_in_ci():
self.assertGreater(output_throughput, 18)
def test_gsm8k(self):
model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
url = urlparse(base_url)
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--attention-backend",
"ascend",
"--mem-fraction-static",
0.8,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{url.hostname}",
port=int(url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(metrics["accuracy"], 0.62)
self.assertLessEqual(metrics["latency"], 150)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
"""
Usage:
python3 -m unittest test_ascend_mla_backend.TestAscendMLABackend.test_gsm8k
"""
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3"
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
)
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
DEFAULT_MODEL_NAME_FOR_TEST = "/models/DeepSeek-V2-Lite-Chat"
if not os.path.exists(DEFAULT_MODEL_NAME_FOR_TEST):
DEFAULT_MODEL_NAME_FOR_TEST = DEFAULT_MLA_MODEL_NAME_FOR_TEST
class TestAscendMLABackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_offline_throughput(
DEFAULT_MODEL_NAME_FOR_TEST,
[
"--attention-backend",
"ascend",
"--mem-fraction-static",
0.7,
"--tp-size",
"4",
"--trust-remote-code",
"--disable-cuda-graph",
],
)
print(f"{output_throughput=}")
if is_in_ci():
self.assertGreater(output_throughput, 18)
def test_gsm8k(self):
model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
url = urlparse(base_url)
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--attention-backend",
"ascend",
"--mem-fraction-static",
0.7,
"--tp-size",
"4",
"--trust-remote-code",
"--disable-cuda-graph",
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=128,
max_new_tokens=512,
parallel=128,
host=f"http://{url.hostname}",
port=int(url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(metrics["accuracy"], 0.62)
self.assertGreaterEqual(metrics["output_throughput"], 50)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()