Unverified Commit f92b729d authored by ZhengdQin's avatar ZhengdQin Committed by GitHub
Browse files

[new feat] ascend backend support fia fusion kernel (#8328)


Co-authored-by: default avatarEven Zhou <even.y.zhou@outlook.com>
parent e2e378ca
...@@ -47,7 +47,7 @@ jobs: ...@@ -47,7 +47,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test - name: Run test
timeout-minutes: 30 timeout-minutes: 60
env: env:
SGLANG_USE_MODELSCOPE: true SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true SGLANG_IS_IN_CI: true
...@@ -82,7 +82,7 @@ jobs: ...@@ -82,7 +82,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test - name: Run test
timeout-minutes: 30 timeout-minutes: 90
env: env:
SGLANG_USE_MODELSCOPE: true SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true SGLANG_IS_IN_CI: true
...@@ -117,7 +117,7 @@ jobs: ...@@ -117,7 +117,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test - name: Run test
timeout-minutes: 30 timeout-minutes: 60
env: env:
SGLANG_USE_MODELSCOPE: true SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true SGLANG_IS_IN_CI: true
......
...@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend ...@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
import os
import numpy as np
@dataclass @dataclass
class ForwardMetadata: class ForwardMetadata:
...@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend): ...@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
super().__init__() super().__init__()
self.forward_metadata = None self.forward_metadata = None
self.device = model_runner.device self.device = model_runner.device
self.gen_attention_mask(128, model_runner.dtype)
self.page_size = model_runner.page_size self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
if self.use_mla: if self.use_mla:
...@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend): ...@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len self.max_context_len = model_runner.model_config.context_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.graph_mode = False self.graph_mode = False
self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
if not self.use_fia:
self.gen_attention_mask(128, model_runner.dtype)
mask_length = 2048
self.fia_mask = ~torch.tril(
torch.ones(
(mask_length, mask_length),
dtype=torch.bool,
device=model_runner.device,
)
)
def init_forward_metadata(self, forward_batch: ForwardBatch): def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass.""" """Init the metadata for a forward pass."""
...@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend): ...@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens.cpu().int() forward_batch.extend_seq_lens.cpu().int()
) )
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
forward_batch.extend_seq_lens_cpu
)
self.graph_mode = False self.graph_mode = False
...@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend): ...@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
save_kv_cache=True, save_kv_cache=True,
): ):
if save_kv_cache: if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer( if save_kv_cache:
layer, forward_batch.out_cache_loc, k, v forward_batch.token_to_kv_pool.set_kv_buffer(
) layer, forward_batch.out_cache_loc, k, v
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if not self.use_mla: if self.use_fia:
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim) """FIA will support multi-bs in the later version of CANN"""
output = torch.empty( q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim), attn_output = torch.empty(
dtype=query.dtype, (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
device=query.device, device=q.device,
) dtype=q.dtype,
)
q_len_offset = 0
for q_len in forward_batch.extend_seq_lens_cpu:
attn_output[q_len_offset : q_len_offset + q_len] = (
torch.ops.npu.npu_fused_infer_attention_score(
q[None, q_len_offset : q_len_offset + q_len],
k[None, q_len_offset : q_len_offset + q_len],
v[None, q_len_offset : q_len_offset + q_len],
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND", # todo, TND not supports q_heads!=k_heads
atten_mask=self.fia_mask.unsqueeze(0),
sparse_mode=3,
scale=layer.scaling,
next_tokens=0,
)[0]
)
q_len_offset += q_len
attn_output = attn_output.view(
-1, layer.tp_q_head_num * layer.v_head_dim
)
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=k_cache,
value_cache=v_cache,
mask=self.mask,
block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
context_lens=self.forward_metadata.seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=output,
)
return output
else:
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else: else:
o = torch.empty_like(q) query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
attn_output = torch.empty(
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
dtype=query.dtype,
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) device=query.device,
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim) )
causal = True torch_npu._npu_flash_attention_qlens(
if ( query=query,
layer.is_cross_attention key_cache=k_cache,
or layer.attn_type == AttentionType.ENCODER_ONLY value_cache=v_cache,
): mask=self.mask,
causal = False block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
self.native_attn._run_sdpa_forward_extend( context_lens=self.forward_metadata.seq_lens_cpu_int,
q_, scale_value=layer.scaling,
o_, num_heads=layer.tp_q_head_num,
k_cache.view( num_kv_heads=layer.tp_k_head_num,
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim) out=attn_output,
), )
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank), else:
forward_batch.req_to_token_pool.req_to_token, assert (
forward_batch.req_pool_indices, layer.qk_head_dim != layer.v_head_dim
forward_batch.seq_lens, ), "FIA only supports qk_head_dim != v_head_dim"
forward_batch.extend_prefix_lens, q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
forward_batch.extend_seq_lens, k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
scaling=layer.scaling,
enable_gqa=use_gqa, attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
causal=causal, q_nope,
k_nope,
v,
query_rope=q_rope,
key_rope=k_rope,
num_heads=layer.tp_q_head_num,
input_layout="TND",
atten_mask=self.fia_mask,
sparse_mode=3,
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
scale=layer.scaling,
next_tokens=0,
) )
return o
return attn_output
def forward_decode( def forward_decode(
self, self,
...@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend): ...@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
v: torch.Tensor, v: torch.Tensor,
layer: RadixAttention, layer: RadixAttention,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
save_kv_cache=True, save_kv_cache: bool = False,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
): ):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla: if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
num_tokens = q.shape[0]
if self.graph_mode: if self.graph_mode:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer( k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id layer.layer_id
...@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend): ...@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
layer.layer_id layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
num_tokens = query.shape[0]
workspace = ( workspace = (
torch_npu._npu_fused_infer_attention_score_get_max_workspace( torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query, query,
...@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend): ...@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
) )
) )
output = torch.empty( attn_output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype, dtype=q.dtype,
device=q.device, device=q.device,
...@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend): ...@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
scale=layer.scaling, scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
workspace=workspace, workspace=workspace,
out=[output, softmax_lse], out=[attn_output, softmax_lse],
) )
else: else:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer( v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id layer.layer_id
) )
if self.use_fia:
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
layer.qk_head_dim,
),
k_cache.view(
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
),
v_cache.view(
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
),
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND",
atten_mask=None,
block_size=self.page_size,
block_table=self.forward_metadata.block_tables,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
scale=layer.scaling,
)
else:
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) torch_npu._npu_paged_attention(
num_tokens = query.shape[0] query=query,
output = torch.empty( key_cache=k_cache,
(num_tokens, layer.tp_q_head_num, layer.v_head_dim), value_cache=v_cache,
dtype=query.dtype, num_heads=layer.tp_q_head_num,
device=query.device, num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
) )
num_tokens = q.shape[0]
torch_npu._npu_paged_attention( kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
query=query, k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
key_cache=k_cache,
value_cache=v_cache, if (self.graph_mode or self.use_fia) and (
layer.tp_q_head_num // layer.tp_k_head_num
) >= 8:
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
kv_c = kv_c.view(
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
)
k_pe = k_pe.view(
-1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
)
q = q.view(
forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
)
q_rope = q_rope.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
self.qk_rope_head_dim,
)
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q,
kv_c,
kv_c,
query_rope=q_rope,
key_rope=k_pe,
num_heads=layer.tp_q_head_num, num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND",
atten_mask=None,
sparse_mode=0,
scale=layer.scaling,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
)
else:
assert (
self.graph_mode == False
) # _npu_paged_attention_mla not support graph mode
q = torch.cat([q, q_rope], dim=-1)
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)
attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num, num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling, scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables, block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int, context_lens=self.forward_metadata.seq_lens_cpu_int,
out=output, mla_vheadsize=self.kv_lora_rank,
out=attn_output,
) )
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
num_tokens = query.shape[0]
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)
attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mla_vheadsize=self.kv_lora_rank,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank) return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
...@@ -304,7 +304,7 @@ class TopK(CustomOp): ...@@ -304,7 +304,7 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1] global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256 and self.topk_config.renormalize is False: if global_num_experts == 256 and self.topk_config.renormalize is True:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1 routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32) router_logits = router_logits.to(torch.float32)
......
...@@ -36,12 +36,15 @@ import triton.language as tl ...@@ -36,12 +36,15 @@ import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024 GB = 1024 * 1024 * 1024
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_npu = is_npu()
if _is_npu:
import torch_npu
class ReqToTokenPool: class ReqToTokenPool:
...@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool): ...@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
cache_k = cache_k.view(self.store_dtype) cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype)
import torch_npu
torch_npu._npu_reshape_and_cache( torch_npu._npu_reshape_and_cache(
key=cache_k, key=cache_k,
value=cache_v, value=cache_v,
...@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): ...@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = torch.zeros( self.k_buffer = torch.zeros(
( (
layer_num, layer_num,
self.size // self.page_size + 1, self.size // self.page_size + 1,
self.page_size, self.page_size,
self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank,
),
dtype=self.store_dtype,
device=self.device,
)
self.v_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
self.qk_rope_head_dim,
), ),
dtype=self.store_dtype, dtype=self.store_dtype,
device=self.device, device=self.device,
...@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): ...@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
) )
self.mem_usage = kv_size / GB self.mem_usage = kv_size / GB
def get_kv_size_bytes(self):
assert hasattr(self, "k_buffer")
assert hasattr(self, "v_buffer")
kv_size_bytes = 0
for k_cache in self.k_buffer:
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
for v_cache in self.v_buffer:
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return (
self.k_buffer[layer_id - self.start_layer],
self.v_buffer[layer_id - self.start_layer],
)
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
# for disagg # for disagg
def get_contiguous_buf_infos(self): def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned. # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)] kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)] self.v_buffer[i].data_ptr() for i in range(self.layer_num)
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)] ]
kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i].nbytes for i in range(self.layer_num)
]
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer( def set_kv_buffer(
...@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): ...@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype) cache_k = cache_k.view(self.store_dtype)
import torch_npu if cache_v is None:
cache_k, cache_v = cache_k.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
torch_npu._npu_reshape_and_cache_siso( torch_npu.npu_scatter_nd_update_(
key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim), self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
key_cache=self.kv_buffer[layer_id - self.start_layer].view( loc.view(-1, 1),
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim cache_k.view(-1, 1, self.kv_lora_rank),
)
torch_npu.npu_scatter_nd_update_(
self.v_buffer[layer_id - self.start_layer].view(
-1, 1, self.qk_rope_head_dim
), ),
slot_indices=loc, loc.view(-1, 1),
cache_v.view(-1, 1, self.qk_rope_head_dim),
) )
......
...@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
self.current_attention_backend = attention_backend self.current_attention_backend = attention_backend
if attention_backend == "ascend": if attention_backend == "ascend":
return AttnForwardMethod.MLA if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
elif ( elif (
attention_backend == "flashinfer" attention_backend == "flashinfer"
or attention_backend == "fa3" or attention_backend == "fa3"
...@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
or self.current_attention_backend == "flashinfer" or self.current_attention_backend == "flashinfer"
or self.current_attention_backend == "cutlass_mla" or self.current_attention_backend == "cutlass_mla"
or self.current_attention_backend == "trtllm_mla" or self.current_attention_backend == "trtllm_mla"
or self.current_attention_backend == "ascend"
): ):
extra_args = {} extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch): if self._fuse_rope_for_trtllm_mla(forward_batch):
......
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
TEST_MODEL_MATRIX = {
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
"accuracy": 0.34,
"latency": 1000,
"output_throughput": 6,
},
}
class TestAscendMlaW8A8Int8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--disable-cuda-graph",
"--mem-fraction-static",
0.8,
"--attention-backend",
"ascend",
"--quantization",
"w8a8_int8",
"--tp-size",
2,
"--disable-radix-cache",
]
def test_a_gsm8k(self):
os.environ["ASCEND_USE_FIA"] = "true"
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
...@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase): ...@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
"w8a8_int8", "w8a8_int8",
"--tp-size", "--tp-size",
4, 4,
"--disable-radix-cache",
] ]
def test_a_gsm8k(self): def test_a_gsm8k(self):
......
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
TEST_MODEL_MATRIX = {
"Qwen/Qwen2.5-7B-Instruct": {
"accuracy": 0.85,
"latency": 180,
"output_throughput": 20,
},
}
class TestAscendTp2Bf16(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--disable-cuda-graph",
"--mem-fraction-static",
0.8,
"--attention-backend",
"ascend",
"--tp-size",
2,
"--disable-radix-cache",
]
def test_a_gsm8k(self):
os.environ["ASCEND_USE_FIA"] = "true"
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
...@@ -275,6 +275,8 @@ suite_ascend = { ...@@ -275,6 +275,8 @@ suite_ascend = {
"per-commit-2-ascend-npu": [ "per-commit-2-ascend-npu": [
TestFile("ascend/test_ascend_tp2_bf16.py", 400), TestFile("ascend/test_ascend_tp2_bf16.py", 400),
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
], ],
"per-commit-4-ascend-npu": [ "per-commit-4-ascend-npu": [
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400), TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment