Unverified Commit f92b729d authored by ZhengdQin's avatar ZhengdQin Committed by GitHub
Browse files

[new feat] ascend backend support fia fusion kernel (#8328)


Co-authored-by: default avatarEven Zhou <even.y.zhou@outlook.com>
parent e2e378ca
......@@ -47,7 +47,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test
timeout-minutes: 30
timeout-minutes: 60
env:
SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true
......@@ -82,7 +82,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test
timeout-minutes: 30
timeout-minutes: 90
env:
SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true
......@@ -117,7 +117,7 @@ jobs:
curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
- name: Run test
timeout-minutes: 30
timeout-minutes: 60
env:
SGLANG_USE_MODELSCOPE: true
SGLANG_IS_IN_CI: true
......
......@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
import os
import numpy as np
@dataclass
class ForwardMetadata:
......@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
super().__init__()
self.forward_metadata = None
self.device = model_runner.device
self.gen_attention_mask(128, model_runner.dtype)
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
if self.use_mla:
......@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.graph_mode = False
self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
if not self.use_fia:
self.gen_attention_mask(128, model_runner.dtype)
mask_length = 2048
self.fia_mask = ~torch.tril(
torch.ones(
(mask_length, mask_length),
dtype=torch.bool,
device=model_runner.device,
)
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
......@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
forward_batch.extend_seq_lens_cpu
)
self.graph_mode = False
......@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if not self.use_mla:
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
output = torch.empty(
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
if self.use_fia:
"""FIA will support multi-bs in the later version of CANN"""
q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(q.size(0), layer.tp_q_head_num, layer.v_head_dim),
device=q.device,
dtype=q.dtype,
)
q_len_offset = 0
for q_len in forward_batch.extend_seq_lens_cpu:
attn_output[q_len_offset : q_len_offset + q_len] = (
torch.ops.npu.npu_fused_infer_attention_score(
q[None, q_len_offset : q_len_offset + q_len],
k[None, q_len_offset : q_len_offset + q_len],
v[None, q_len_offset : q_len_offset + q_len],
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND", # todo, TND not supports q_heads!=k_heads
atten_mask=self.fia_mask.unsqueeze(0),
sparse_mode=3,
scale=layer.scaling,
next_tokens=0,
)[0]
)
q_len_offset += q_len
attn_output = attn_output.view(
-1, layer.tp_q_head_num * layer.v_head_dim
)
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=k_cache,
value_cache=v_cache,
mask=self.mask,
block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
context_lens=self.forward_metadata.seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=output,
)
return output
else:
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
causal = True
if (
layer.is_cross_attention
or layer.attn_type == AttentionType.ENCODER_ONLY
):
causal = False
self.native_attn._run_sdpa_forward_extend(
q_,
o_,
k_cache.view(
-1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
),
v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_prefix_lens,
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=causal,
query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
attn_output = torch.empty(
(query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=k_cache,
value_cache=v_cache,
mask=self.mask,
block_table=self.forward_metadata.block_tables,
seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
context_lens=self.forward_metadata.seq_lens_cpu_int,
scale_value=layer.scaling,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
out=attn_output,
)
else:
assert (
layer.qk_head_dim != layer.v_head_dim
), "FIA only supports qk_head_dim != v_head_dim"
q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
v,
query_rope=q_rope,
key_rope=k_rope,
num_heads=layer.tp_q_head_num,
input_layout="TND",
atten_mask=self.fia_mask,
sparse_mode=3,
actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
scale=layer.scaling,
next_tokens=0,
)
return o
return attn_output
def forward_decode(
self,
......@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
save_kv_cache: bool = False,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
num_tokens = q.shape[0]
if self.graph_mode:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
......@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
layer.layer_id
).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
num_tokens = query.shape[0]
workspace = (
torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query,
......@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
)
)
output = torch.empty(
attn_output = torch.empty(
(num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
dtype=q.dtype,
device=q.device,
......@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
scale=layer.scaling,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
workspace=workspace,
out=[output, softmax_lse],
out=[attn_output, softmax_lse],
)
else:
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
layer.layer_id
)
if self.use_fia:
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
layer.qk_head_dim,
),
k_cache.view(
-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
),
v_cache.view(
-1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
),
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND",
atten_mask=None,
block_size=self.page_size,
block_table=self.forward_metadata.block_tables,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
scale=layer.scaling,
)
else:
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
attn_output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
)
query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
num_tokens = query.shape[0]
output = torch.empty(
(num_tokens, layer.tp_q_head_num, layer.v_head_dim),
dtype=query.dtype,
device=query.device,
torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_heads=layer.tp_q_head_num,
num_kv_heads=layer.tp_k_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
torch_npu._npu_paged_attention(
query=query,
key_cache=k_cache,
value_cache=v_cache,
num_tokens = q.shape[0]
kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
if (self.graph_mode or self.use_fia) and (
layer.tp_q_head_num // layer.tp_k_head_num
) >= 8:
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
kv_c = kv_c.view(
-1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
)
k_pe = k_pe.view(
-1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
)
q = q.view(
forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
)
q_rope = q_rope.view(
forward_batch.batch_size,
-1,
layer.tp_q_head_num,
self.qk_rope_head_dim,
)
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
q,
kv_c,
kv_c,
query_rope=q_rope,
key_rope=k_pe,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="BSND",
atten_mask=None,
sparse_mode=0,
scale=layer.scaling,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
)
else:
assert (
self.graph_mode == False
) # _npu_paged_attention_mla not support graph mode
q = torch.cat([q, q_rope], dim=-1)
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)
attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
out=output,
mla_vheadsize=self.kv_lora_rank,
out=attn_output,
)
return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
else:
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
num_tokens = query.shape[0]
kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-1,
self.page_size,
layer.tp_k_head_num,
self.kv_lora_rank + self.qk_rope_head_dim,
)
attn_output = torch.empty(
[num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
torch_npu._npu_paged_attention_mla(
query=query,
key_cache=kv_c_and_k_pe_cache,
num_kv_heads=layer.tp_k_head_num,
num_heads=layer.tp_q_head_num,
scale_value=layer.scaling,
block_table=self.forward_metadata.block_tables,
context_lens=self.forward_metadata.seq_lens_cpu_int,
mla_vheadsize=self.kv_lora_rank,
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
......@@ -304,7 +304,7 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256 and self.topk_config.renormalize is False:
if global_num_experts == 256 and self.topk_config.renormalize is True:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32)
......
......@@ -36,12 +36,15 @@ import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_npu:
import torch_npu
class ReqToTokenPool:
......@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
import torch_npu
torch_npu._npu_reshape_and_cache(
key=cache_k,
value=cache_v,
......@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = torch.zeros(
self.k_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
self.kv_lora_rank + self.qk_rope_head_dim,
self.kv_lora_rank,
),
dtype=self.store_dtype,
device=self.device,
)
self.v_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
self.qk_rope_head_dim,
),
dtype=self.store_dtype,
device=self.device,
......@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
)
self.mem_usage = kv_size / GB
def get_kv_size_bytes(self):
assert hasattr(self, "k_buffer")
assert hasattr(self, "v_buffer")
kv_size_bytes = 0
for k_cache in self.k_buffer:
kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
for v_cache in self.v_buffer:
kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return (
self.k_buffer[layer_id - self.start_layer],
self.v_buffer[layer_id - self.start_layer],
)
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
self.v_buffer[i].data_ptr() for i in range(self.layer_num)
]
kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i].nbytes for i in range(self.layer_num)
]
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
......@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
if self.store_dtype != self.dtype:
cache_k = cache_k.view(self.store_dtype)
import torch_npu
if cache_v is None:
cache_k, cache_v = cache_k.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
torch_npu._npu_reshape_and_cache_siso(
key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
torch_npu.npu_scatter_nd_update_(
self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
loc.view(-1, 1),
cache_k.view(-1, 1, self.kv_lora_rank),
)
torch_npu.npu_scatter_nd_update_(
self.v_buffer[layer_id - self.start_layer].view(
-1, 1, self.qk_rope_head_dim
),
slot_indices=loc,
loc.view(-1, 1),
cache_v.view(-1, 1, self.qk_rope_head_dim),
)
......
......@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
self.current_attention_backend = attention_backend
if attention_backend == "ascend":
return AttnForwardMethod.MLA
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
elif (
attention_backend == "flashinfer"
or attention_backend == "fa3"
......@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
or self.current_attention_backend == "flashinfer"
or self.current_attention_backend == "cutlass_mla"
or self.current_attention_backend == "trtllm_mla"
or self.current_attention_backend == "ascend"
):
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
......
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
TEST_MODEL_MATRIX = {
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
"accuracy": 0.34,
"latency": 1000,
"output_throughput": 6,
},
}
class TestAscendMlaW8A8Int8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--disable-cuda-graph",
"--mem-fraction-static",
0.8,
"--attention-backend",
"ascend",
"--quantization",
"w8a8_int8",
"--tp-size",
2,
"--disable-radix-cache",
]
def test_a_gsm8k(self):
os.environ["ASCEND_USE_FIA"] = "true"
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
......@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
"w8a8_int8",
"--tp-size",
4,
"--disable-radix-cache",
]
def test_a_gsm8k(self):
......
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
TEST_MODEL_MATRIX = {
"Qwen/Qwen2.5-7B-Instruct": {
"accuracy": 0.85,
"latency": 180,
"output_throughput": 20,
},
}
class TestAscendTp2Bf16(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--disable-cuda-graph",
"--mem-fraction-static",
0.8,
"--attention-backend",
"ascend",
"--tp-size",
2,
"--disable-radix-cache",
]
def test_a_gsm8k(self):
os.environ["ASCEND_USE_FIA"] = "true"
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
......@@ -275,6 +275,8 @@ suite_ascend = {
"per-commit-2-ascend-npu": [
TestFile("ascend/test_ascend_tp2_bf16.py", 400),
TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
],
"per-commit-4-ascend-npu": [
TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment