Unverified commit 74dd4249, authored by chenxu140 and committed by GitHub

[Feature] Support NPUGraph for DeepSeek on Ascend NPU (#9355)


Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
parent dc20c22f
import concurrent.futures
import logging
from typing import List, Tuple
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
# Group contiguous prefill/decode indices into transfer blocks
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)
num_layers = len(self.kv_args.kv_data_ptrs)
layers_params = [
(
self.kv_args.kv_data_ptrs[layer_id],
dst_kv_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
)
for layer_id in range(num_layers)
]
def set_transfer_blocks(
src_ptr: int, dst_ptr: int, item_len: int
) -> List[Tuple[int, int, int]]:
transfer_blocks = []
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index)
transfer_blocks.append((src_addr, dst_addr, length))
return transfer_blocks
# Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
return self._transfer_data(mooncake_session_id, transfer_blocks)
# Worker function for processing all layers in a batch
def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
transfer_blocks = []
for src_ptr, dst_ptr, item_len in layers_params:
transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
return self._transfer_data(mooncake_session_id, transfer_blocks)
if self.enable_custom_mem_pool:
futures = [
executor.submit(
process_layer,
src_ptr,
dst_ptr,
item_len,
)
for (src_ptr, dst_ptr, item_len) in layers_params
]
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
for f in futures:
f.cancel()
return status
else:
# Combining all layers' params into one batch transfer is more efficient
# than using multiple threads
return process_layers(layers_params)
return 0
class AscendKVSender(MooncakeKVSender):
pass
......
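Illustrative sketch (not from the commit): how contiguous prefill/decode page indices collapse into (src_addr, dst_addr, length) transfer blocks, which is the idea behind group_concurrent_contiguous() and set_transfer_blocks() in send_kvcache above. The grouping helper below is a simplified stand-in for the real sglang utility, not its actual implementation.

import numpy as np


def group_contiguous(src: np.ndarray, dst: np.ndarray):
    """Split src/dst index arrays into runs that are contiguous on both sides."""
    src_groups, dst_groups = [], []
    start = 0
    for i in range(1, len(src)):
        if src[i] != src[i - 1] + 1 or dst[i] != dst[i - 1] + 1:
            src_groups.append(src[start:i])
            dst_groups.append(dst[start:i])
            start = i
    if len(src) > 0:
        src_groups.append(src[start:])
        dst_groups.append(dst[start:])
    return src_groups, dst_groups


def to_transfer_blocks(src_groups, dst_groups, src_ptr, dst_ptr, item_len):
    """One (src_addr, dst_addr, length) block per contiguous run, as in set_transfer_blocks()."""
    return [
        (src_ptr + int(s[0]) * item_len, dst_ptr + int(d[0]) * item_len, item_len * len(s))
        for s, d in zip(src_groups, dst_groups)
    ]


# Pages 4,5,6 and 9 map to decode pages 12,13,14 and 30 -> two transfer blocks.
src = np.array([4, 5, 6, 9], dtype=np.int32)
dst = np.array([12, 13, 14, 30], dtype=np.int32)
blocks = to_transfer_blocks(*group_contiguous(src, dst), 0, 0, 4096)
# blocks == [(16384, 49152, 12288), (36864, 122880, 4096)]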
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
self.graph_mode = True
def get_cuda_graph_seq_len_fill_value(self):
return 1
return 0
def forward_extend(
self,
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
save_kv_cache: bool = True,
):
if not self.use_mla:
if save_kv_cache:
@@ -253,6 +253,136 @@ class AscendAttnBackend(AttentionBackend):
return attn_output
def forward_decode_graph(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if save_kv_cache:
if self.use_mla:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
else:
actual_seq_len_kv = (
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
)
num_tokens = query.shape[0]
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
)
output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
workspace=workspace,
out=[output, softmax_lse],
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
k_rope_cache = k_rope.view(
-1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
)
c_kv_cache = c_kv.view(
-1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
)
q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
else:
actual_seq_len_kv = (
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
)
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope,
c_kv_cache,
c_kv_cache,
query_rope=q_rope,
key_rope=k_rope_cache,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
input_layout="BNSD",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
antiquant_mode=0,
antiquant_scale=None,
sparse_mode=0,
)
output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
c_kv_cache,
c_kv_cache,
query_rope=q_rope,
key_rope=k_rope_cache,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
input_layout="BNSD",
scale=layer.scaling,
actual_seq_lengths_kv=actual_seq_len_kv,
antiquant_mode=0,
antiquant_scale=None,
sparse_mode=0,
workspace=workspace,
out=[output, softmax_lse],
)
return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
def forward_decode(
self,
q: torch.Tensor,
@@ -260,106 +390,73 @@ class AscendAttnBackend(AttentionBackend):
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = False,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if self.graph_mode:
return self.forward_decode_graph(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope=q_rope,
k_rope=k_rope,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
num_tokens = q.shape[0]
if self.graph_mode:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
workspace = (
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
)
)
attn_output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
query,
k_cache,
v_cache,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if self.use_fia:
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
layer.qk_head_dim,
),
k_cache.view(
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
),
v_cache.view(
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
),
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSH",
input_layout="BSND",
atten_mask=None,
block_size=self.page_size,
block_table=self.forward_metadata.block_tables,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
workspace=workspace,
out=[attn_output, softmax_lse],
)
else:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
if self.use_fia:
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
layer.qk_head_dim,
),
k_cache.view(
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
),
v_cache.view(
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
),
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND",
atten_mask=None,
block_size=self.page_size,
block_table=self.forward_metadata.block_tables,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
scale=layer.scaling,
)
else:
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=attn_output,
)
torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
if save_kv_cache:
@@ -370,9 +467,7 @@ class AscendAttnBackend(AttentionBackend):
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if (self.graph_mode or self.use_fia) and (
layer.tp_q_head_num // layer.tp_k_head_num
) >= 8:
if self.use_fia and (layer.tp_q_head_num // layer.tp_k_head_num) >= 8:
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
kv_c = kv_c.view(
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
......
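Condensed sketch of the graph-friendly decode path added in forward_decode_graph above: the maximum workspace is queried once, the output tensors are preallocated, and the in-place ".out" variant of the fused attention kernel writes into those fixed buffers, so no allocation happens while the NPU graph is captured or replayed. Shapes and the helper name are placeholders; the torch_npu calls mirror the ones in the diff rather than documenting the full API.

import torch
import torch_npu  # requires an Ascend NPU environment


def graph_safe_decode_attention(query, k_cache, v_cache, block_tables, seq_lens_kv,
                                page_size, num_heads, num_kv_heads, scale, v_head_dim):
    # 1) Query the maximum workspace the fused kernel may need for these shapes.
    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
        query, k_cache, v_cache,
        block_table=block_tables,
        block_size=page_size,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        input_layout="BSH",
        scale=scale,
        actual_seq_lengths_kv=seq_lens_kv,
    )
    # 2) Preallocate outputs so the captured graph replays with stable addresses.
    num_tokens = query.shape[0]
    output = torch.empty((num_tokens, 1, num_heads * v_head_dim),
                         dtype=query.dtype, device=query.device)
    softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
    # 3) Run the kernel, writing into the preallocated buffers.
    torch_npu.npu_fused_infer_attention_score.out(
        query, k_cache, v_cache,
        block_table=block_tables,
        block_size=page_size,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        input_layout="BSH",
        scale=scale,
        actual_seq_lengths_kv=seq_lens_kv,
        workspace=workspace,
        out=[output, softmax_lse],
    )
    return output.view(num_tokens, num_heads * v_head_dim)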
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
output_dtype=output_dtype,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=seg_indptr,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
......
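The MoE change above folds dequantization, SwiGLU, and re-quantization into a single torch_npu.npu_dequant_swiglu_quant call instead of running npu_swiglu and npu_dynamic_quant on a dequantized output. The snippet below is a rough, unfused reference in plain PyTorch for what that fusion is expected to compute (gate taken from the left half per activate_left=True, per-token dynamic int8 output scale); the exact kernel semantics, including the grouped/expert dimension handled via group_index, are defined by CANN and omitted here, and the function name is hypothetical.

import torch


def dequant_swiglu_quant_reference(x_int32, weight_scale, per_token_scale):
    # Dequantize the int32 grouped-matmul output: per-channel weight scale
    # times per-token activation scale.
    x = x_int32.to(torch.float32) * weight_scale * per_token_scale[:, None]
    gate, up = x.chunk(2, dim=-1)        # left half acts as the gate
    y = torch.nn.functional.silu(gate) * up
    # Dynamic per-token symmetric int8 re-quantization for the down_proj GMM.
    out_scale = y.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / 127.0
    y_int8 = torch.clamp(torch.round(y / out_scale), -128, 127).to(torch.int8)
    return y_int8, out_scale.squeeze(-1)


x_int32 = torch.randint(-128, 128, (4, 256), dtype=torch.int32)
w_scale = torch.full((256,), 1e-3)
t_scale = torch.full((4,), 2e-2)
y_int8, y_scale = dequant_swiglu_quant_reference(x_int32, w_scale, t_scale)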
@@ -304,12 +304,12 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1]
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256 and self.topk_config.renormalize is True:
if global_num_experts == 256:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k(
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.topk_config.top_k,
bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20),
)
if self.topk_config.renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if self.topk_config.num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return StandardTopKOutput(topk_weights, topk_ids, _)
else:
self.topk_config.torch_native = True
return select_experts(
......
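Minimal sketch (not the sglang code path) of the renormalization added above: when a fused shared expert occupies the last top-k slot, its weight is excluded from the normalizing sum so that only the routed experts' weights are rescaled to sum to one.

import torch


def renormalize_topk_weights(topk_weights: torch.Tensor, num_fused_shared_experts: int):
    if num_fused_shared_experts == 0:
        denom = topk_weights.sum(dim=-1, keepdim=True)
    else:
        # Exclude the trailing fused-shared-expert column from the sum.
        denom = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    return topk_weights / denom


w = torch.tensor([[0.2, 0.1, 0.1, 0.4]])   # last column = fused shared expert
print(renormalize_topk_weights(w, 1))      # [[0.5, 0.25, 0.25, 1.0]]; routed weights sum to 1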
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
return params_dict
@staticmethod
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
if original_dtype != torch.int8:
x = torch_npu.npu_quantize(
x,
layer.aclnn_input_scale,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
torch.qint8,
-1,
True,
False,
)
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
requires_grad=False,
)
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
requires_grad=False,
)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
requires_grad=False,
......
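Sketch of the quantization tweak above: precomputing aclnn_input_scale_reciprocal and flipping the last npu_quantize argument from True to False replaces a per-call division by input_scale with a multiplication by its precomputed reciprocal, which is the same math. The affine int8 formula and function names below are a generic illustration, not the exact ACLNN semantics.

import torch


def quantize_div(x, scale, offset):
    return torch.clamp(torch.round(x / scale) + offset, -128, 127).to(torch.int8)


def quantize_mul(x, scale_reciprocal, offset):
    return torch.clamp(torch.round(x * scale_reciprocal) + offset, -128, 127).to(torch.int8)


x = torch.randn(4, 8)
scale, offset = torch.tensor(0.05), torch.tensor(0.0)
q_div = quantize_div(x, scale, offset)
q_mul = quantize_mul(x, 1.0 / scale, offset)
# q_div and q_mul agree except for values landing exactly on a rounding boundary.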
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
layer_num,
self.size // self.page_size + 1,
self.page_size,
1,
self.kv_lora_rank,
),
dtype=self.store_dtype,
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
layer_num,
self.size // self.page_size + 1,
self.page_size,
1,
self.qk_rope_head_dim,
),
dtype=self.store_dtype,
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
if cache_v is None:
cache_k, cache_v = cache_k.split(
......
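Illustrative sketch (simplified, with a hypothetical helper name) of the set_kv_buffer behaviour implied above: callers may now pass the latent cache pre-split as (kv_c, k_pe), and the combined tensor is only split inside the pool when cache_v is None. The 512/64 split sizes correspond to DeepSeek's kv_lora_rank and qk_rope_head_dim.

import torch

KV_LORA_RANK, QK_ROPE_HEAD_DIM = 512, 64


def split_latent_if_needed(cache_k: torch.Tensor, cache_v):
    # Combined latent cache: split the trailing dim into (kv_lora_rank, qk_rope_head_dim).
    if cache_v is None:
        cache_k, cache_v = cache_k.split([KV_LORA_RANK, QK_ROPE_HEAD_DIM], dim=-1)
    return cache_k, cache_v


latent = torch.randn(8, 1, KV_LORA_RANK + QK_ROPE_HEAD_DIM)
kv_c, k_pe = split_latent_if_needed(latent, None)     # combined path: split happens here
kv_c2, k_pe2 = split_latent_if_needed(kv_c, k_pe)     # pre-split path (NPU): no split needed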
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
is_flashinfer_available,
is_hip,
is_non_idle_and_non_empty,
is_npu,
is_sm100_supported,
log_info_on_rank0,
make_layers,
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe
if not _is_npu:
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe
# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
)
else:
# Store kv_a and k_pe separately to avoid a costly split operation later
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
)
return q, k, v, forward_batch
......
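Small sanity sketch for the NPU branch above: writing kv_a and k_pe into the pool separately is equivalent to assembling latent_cache and splitting it again when the cache is consumed, so skipping the concat/split round trip changes no values and only removes work. Shapes are illustrative.

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64
kv_a = torch.randn(8, 1, kv_lora_rank)
k_pe = torch.randn(8, 1, qk_rope_head_dim)

# Non-NPU path: pack into latent_cache, split again when the cache is read back.
latent_cache = torch.cat([kv_a, k_pe], dim=-1)
kv_a2, k_pe2 = latent_cache.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

# NPU path: store the two tensors directly; the round trip above is avoided.
assert torch.equal(kv_a, kv_a2) and torch.equal(k_pe, k_pe2)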