Unverified Commit ce6b17c0 authored by Even Zhou and committed by GitHub

[Feature] Support DeepSeek MTP on NPU (#11897)


Co-authored-by: liupeng374 <liupeng374@huawei.com>
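For reference, the flags below reproduce the configuration exercised by the new CI test (ascend/test_ascend_deepseek_mtp.py). This is only a sketch of an equivalent manual launch, assuming the standard sglang.launch_server entry point; the model path and --tp-size are the test's values and will differ per deployment:

    SGLANG_NPU_USE_MLAPO=1 python3 -m sglang.launch_server \
        --model-path /root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8 \
        --trust-remote-code --attention-backend ascend --quantization w8a8_int8 \
        --tp-size 16 --mem-fraction-static 0.8 --disable-radix-cache \
        --chunked-prefill-size 32768 --speculative-algorithm NEXTN \
        --speculative-num-steps 1 --speculative-eagle-topk 1 \
        --speculative-num-draft-tokens 2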
parent cafebef1
......@@ -65,7 +65,7 @@ jobs:
if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
runs-on: linux-arm64-npu-2
strategy:
fail-fast: false
fail-fast: true
matrix:
part: [0, 1, 2]
container:
......@@ -144,6 +144,10 @@ jobs:
per-commit-16-ascend-a3:
if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
runs-on: linux-aarch64-a3-16
strategy:
fail-fast: true
matrix:
part: [0, 1]
container:
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
steps:
......@@ -177,4 +181,4 @@ jobs:
run: |
export PATH="/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}"
cd test/srt
python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600
python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2
......@@ -59,6 +59,19 @@ class AscendAttnBackend(AttentionBackend):
)
self.mask_len = max_seq_len
def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that need to be filled after the draft pass.
Typically, these are tree mask and position buffers.
"""
return [None, None]
def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
pass
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.forward_metadata = None
......@@ -87,15 +100,22 @@ class AscendAttnBackend(AttentionBackend):
device=model_runner.device,
)
)
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
)
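# Causal mask for MTP verify/draft-extend: inverting the lower-triangular matrix marks
# future positions as True (masked) for the fused infer-attention kernel (sparse_mode=3).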
self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
self.mtp_mask = ~self.mtp_mask
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
tp_size = get_attention_tp_size()
self.forward_metadata = ForwardMetadata()
seq_lens_max = forward_batch.seq_lens.max()
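# Target verify appends the draft tokens to each sequence, so the block table
# must cover the extra speculative_num_draft_tokens slots.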
if forward_batch.forward_mode.is_target_verify():
seq_lens_max += self.speculative_num_draft_tokens
self.forward_metadata.block_tables = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
forward_batch.req_pool_indices, :seq_lens_max
][:, :: self.page_size]
// self.page_size
)
......@@ -104,16 +124,23 @@ class AscendAttnBackend(AttentionBackend):
forward_batch.extend_seq_lens.cpu().int()
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
if (
not forward_batch.forward_mode.is_draft_extend_v2()
and not forward_batch.forward_mode.is_draft_extend()
and not forward_batch.forward_mode.is_target_verify()
):
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
if forward_batch.forward_mode.is_target_verify():
self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
self.graph_mode = False
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.graph_metadata = {
"block_tables": torch.empty(
(max_bs, self.max_context_len // self.page_size),
(max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),
dtype=torch.int32,
device=self.device,
),
......@@ -156,6 +183,8 @@ class AscendAttnBackend(AttentionBackend):
):
metadata = self.graph_metadata[bs]
max_len = seq_lens_cpu[:bs].max().item()
if forward_mode.is_target_verify():
max_len += self.speculative_num_draft_tokens
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.block_tables[:bs, :max_seq_pages].copy_(
......@@ -257,6 +286,25 @@ class AscendAttnBackend(AttentionBackend):
k_rope,
topk_indices,
)
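# Speculative-decoding modes (target verify / draft extend) take the fused MTP attention path.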
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
or forward_batch.forward_mode.is_draft_extend_v2()
):
if is_mla_preprocess_enabled():
save_kv_cache = False
return self.forward_mtp(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope=q_rope,
k_rope=k_rope,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
......@@ -393,6 +441,118 @@ class AscendAttnBackend(AttentionBackend):
)
return attn_output
def forward_mtp(
self,
q,
k,
v,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if save_kv_cache:
if self.use_mla:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
k_rope_cache = k_rope.view(
-1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
)
c_kv_cache = c_kv.view(
-1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
)
q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
if not self.graph_mode:
num_token_padding = q.shape[0]
q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]
q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list
else:
actual_seq_lengths_kv = (
self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
)
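# actual_seq_lengths holds cumulative query-token offsets per request (TND layout):
# variable-length for draft extend, a fixed speculative_num_draft_tokens per request otherwise.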
if forward_batch.forward_mode.is_draft_extend():
actual_seq_lengths = (
np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()
)
else:
actual_seq_lengths = np.arange(
self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens + q_nope.shape[0],
self.speculative_num_draft_tokens,
)
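# Query the fused kernel's workspace requirement first, then launch it into preallocated outputs.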
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope,
c_kv_cache,
c_kv_cache,
query_rope=q_rope,
key_rope=k_rope_cache,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="TND",
scale=layer.scaling,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
sparse_mode=3,
atten_mask=self.mtp_mask,
actual_seq_lengths=actual_seq_lengths,
actual_seq_lengths_kv=actual_seq_lengths_kv,
)
attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
c_kv_cache,
c_kv_cache,
query_rope=q_rope,
key_rope=k_rope_cache,
num_heads=layer.tp_q_head_num,
num_key_value_heads=layer.tp_k_head_num,
input_layout="TND",
scale=layer.scaling,
antiquant_mode=0,
antiquant_scale=None,
block_table=self.forward_metadata.block_tables,
block_size=self.page_size,
sparse_mode=3,
atten_mask=self.mtp_mask,
actual_seq_lengths=actual_seq_lengths,
actual_seq_lengths_kv=actual_seq_lengths_kv,
workspace=workspace,
out=[attn_output, softmax_lse],
)
attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
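# The queries were truncated to the non-padded tokens above, so zero-pad the
# output back to the caller's padded length.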
if (
not self.graph_mode
and forward_batch.num_token_non_padded_cpu != num_token_padding
):
attn_output = torch.cat(
[
attn_output,
attn_output.new_zeros(
num_token_padding - attn_output.shape[0], *attn_output.shape[1:]
),
],
dim=0,
)
return attn_output
def forward_decode_graph(
self,
q: torch.Tensor,
......@@ -690,3 +850,71 @@ class AscendAttnBackend(AttentionBackend):
out=attn_output,
)
return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
class AscendAttnMultiStepDraftBackend:
"""
Wrap multiple Ascend attention backends into a single backend that serves
multiple consecutive draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.attn_backends = []
for _ in range(self.speculative_num_steps):
self.attn_backends.append(AscendAttnBackend(model_runner))
def common_template(self, forward_batch: ForwardBatch, call_fn):
assert forward_batch.spec_info is not None
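# Only the first (speculative_num_steps - 1) draft steps run attention; the
# last step's tokens go straight to verification, so its backend needs no metadata.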
for i in range(self.speculative_num_steps - 1):
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs, max_num_tokens):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)
self.common_template(forward_batch, call_fn)
......@@ -77,6 +77,9 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import is_npu
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
......@@ -1050,7 +1053,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
has_grammar: bool = False
# Device
device: str = "cuda"
if not _is_npu:
device: str = "cuda"
else:
device: str = "npu"
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
......
......@@ -75,9 +75,13 @@ class NPUGraphRunner(CudaGraphRunner):
# Replay
if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
self.bs - self.raw_bs
)
if forward_batch.forward_mode.is_target_verify():
seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
else:
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
self.bs - self.raw_bs
)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
......
......@@ -38,12 +38,13 @@ from sglang.srt.models.deepseek_v2 import (
enable_nextn_moe_bf16_cast_to_fp8,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_npu = is_npu()
class DeepseekModelNextN(nn.Module):
......@@ -85,13 +86,21 @@ class DeepseekModelNextN(nn.Module):
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
layer_name = "decoder"
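# On NPU, when the draft model reuses the target checkpoint, the NextN decoder
# weights live under "layers.<num_hidden_layers>" instead of "decoder", so adjust
# the prefix used for weight loading.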
if _is_npu and (
get_global_server_args().speculative_draft_model_path
== get_global_server_args().model_path
):
layer_name = "layers." + str(config.num_hidden_layers)
self.decoder = DeepseekV2DecoderLayer(
config,
0,
quant_config=quant_config,
moe_quant_config=moe_quant_config,
is_nextn=True,
prefix=add_prefix("decoder", prefix),
prefix=add_prefix(layer_name, prefix),
alt_stream=self.alt_stream,
)
......
......@@ -290,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and not forward_batch.forward_mode.is_draft_extend_v2()
):
if hasattr(attn, "indexer"):
return AttnForwardMethod.NPU_MLA_SPARSE
......@@ -3753,8 +3754,12 @@ class DeepseekV2ForCausalLM(nn.Module):
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
if not _is_npu:
torch.cuda.empty_cache()
torch.cuda.synchronize()
else:
torch.npu.empty_cache()
torch.npu.synchronize()
@classmethod
def get_model_config_for_expert_location(cls, config):
......
......@@ -49,6 +49,7 @@ class DraftBackendFactory:
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
"nsa": self._create_nsa_decode_backend,
"ascend": self._create_ascend_decode_backend,
}
return self._create_backend(
......@@ -72,6 +73,7 @@ class DraftBackendFactory:
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
"nsa": self._create_nsa_prefill_backend,
"ascend": self._create_ascend_prefill_backend,
}
backend_name = (
"decode_attention_backend"
......@@ -173,6 +175,15 @@ class DraftBackendFactory:
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_ascend_decode_backend(self):
from sglang.srt.layers.attention.ascend_backend import (
AscendAttnMultiStepDraftBackend,
)
return AscendAttnMultiStepDraftBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_flashinfer_prefill_backend(self):
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
......@@ -219,6 +230,11 @@ class DraftBackendFactory:
return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
def _create_ascend_prefill_backend(self):
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
return AscendAttnBackend(self.draft_model_runner)
def _create_flashmla_prefill_backend(self):
logger.warning(
"flashmla prefill backend is not yet supported for draft extend."
......
......@@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
)
from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
TREE_SPEC_KERNEL_AVAILABLE,
align_evict_mask_to_page_size,
assign_req_to_token_pool,
assign_req_to_token_pool_func,
create_accept_length_filter,
create_extend_after_decode_spec_info,
filter_finished_cache_loc_kernel,
......@@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import (
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
_is_npu = is_npu()
if is_cuda():
from sgl_kernel import (
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
elif is_hip():
from sgl_kernel import verify_tree_greedy
logger = logging.getLogger(__name__)
......@@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@classmethod
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
if not _is_npu:
device = "cuda"
else:
device = "npu"
return cls(
draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
draft_token=torch.empty((0,), dtype=torch.long, device=device),
custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
positions=torch.empty((0,), dtype=torch.int64, device=device),
retrive_index=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
(0, num_verify_tokens), -1, dtype=torch.long, device=device
),
retrive_next_token=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
(0, num_verify_tokens), -1, dtype=torch.long, device=device
),
retrive_next_sibling=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
(0, num_verify_tokens), -1, dtype=torch.long, device=device
),
retrive_cum_len=None,
topk=topk,
......@@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
self.last_loc = last_loc
bs = batch.batch_size()
assign_req_to_token_pool[(bs,)](
assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
end_offset,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
bs,
)
def generate_attn_arg_prefill(
......@@ -151,16 +154,17 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
device = req_pool_indices.device
batch_size = len(req_pool_indices)
qo_indptr = torch.arange(
0,
(1 + batch_size) * self.draft_token_num,
step=self.draft_token_num,
dtype=torch.int32,
device="cuda",
device=device,
)
cum_kv_seq_len = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
(batch_size + 1,), dtype=torch.int32, device=device
)
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
......@@ -169,7 +173,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
kv_indices = torch.empty(
paged_kernel_lens_sum + self.draft_token_num * batch_size,
dtype=torch.int32,
device="cuda",
device=device,
)
create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token,
......@@ -226,11 +230,11 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
accept_index = torch.full(
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
)
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)
if bs != len(sampling_info):
sampling_info = copy.deepcopy(sampling_info)
......@@ -254,7 +258,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
linear_penalty = torch.zeros(
(bs, logits_output.next_token_logits.shape[1]),
dtype=torch.float32,
device="cuda",
device=batch.device,
)
sampling_info.apply_logits_bias(linear_penalty)
logits_output.next_token_logits.add_(
......@@ -276,11 +280,10 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
"Falling back to greedy verification."
)
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
verify_tree_greedy(
predict, accept_index, accept_length = verify_tree_greedy_func(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
......@@ -289,7 +292,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict,
topk=self.topk,
)
else:
# apply temperature and get target probs
expanded_temperature = torch.repeat_interleave(
......@@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
draft_probs = torch.zeros(
target_probs.shape, dtype=torch.float32, device="cuda"
target_probs.shape, dtype=torch.float32, device=batch.device
)
# coins for rejection sampling
coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
coins = torch.rand_like(
candidates, dtype=torch.float32, device=batch.device
)
# coins for final sampling
coins_for_final_sampling = torch.rand(
(bs,), dtype=torch.float32, device="cuda"
(bs,), dtype=torch.float32, device=batch.device
)
tree_speculative_sampling_target_only(
predicts=predict, # mutable
......@@ -468,14 +475,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
if not has_finished:
if page_size == 1 or self.topk == 1:
batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)](
assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
bs,
)
else:
batch.out_cache_loc = tgt_cache_loc
......@@ -501,14 +507,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
)
else:
if page_size == 1 or self.topk == 1:
assign_req_to_token_pool[(bs,)](
assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
bs,
)
batch.seq_lens.add_(accept_length + 1)
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
......@@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
device = req_pool_indices.device
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
if paged_kernel_lens_sum is None:
paged_kernel_lens_sum = cum_kv_seq_len[-1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
paged_kernel_lens_sum, dtype=torch.int32, device=device
)
create_flashinfer_kv_indices_triton[(bs,)](
......
......@@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
generate_simulated_accept_index,
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
......@@ -41,11 +46,8 @@ if is_cuda():
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
elif is_hip():
from sgl_kernel import verify_tree_greedy
@triton.jit
......@@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
bs = batch.batch_size()
......@@ -112,15 +114,15 @@ class EagleDraftInputV2Mixin:
extend_num_tokens,
)
assign_req_to_token_pool[(bs,)](
assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
bs,
)
self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional
......@@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin:
bs = len(batch.req_pool_indices)
batch.input_ids = self.draft_token
device = batch.input_ids.device
batch.out_cache_loc = torch.empty(
(bs * self.draft_token_num,),
dtype=torch.int64,
batch.out_cache_loc = assign_extend_cache_locs_func(
req_pool_indices=batch.req_pool_indices,
req_to_token=req_to_token_pool.req_to_token,
start_offset=batch.seq_lens,
end_offset=batch.seq_lens + self.draft_token_num,
batch_size=bs,
draft_token_num=self.draft_token_num,
device=device,
)
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
# Get a forward batch
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.capture_hidden_mode = CaptureHiddenMode.FULL
......@@ -258,11 +254,10 @@ class EagleVerifyInputV2Mixin:
accept_length = torch.empty((bs,), dtype=torch.int32, device=device)
# Sample tokens
if sampling_info.is_all_greedy:
if sampling_info.is_all_greedy or _is_npu:
target_predict = torch.argmax(next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
verify_tree_greedy(
predict, accept_index, accept_length = verify_tree_greedy_func(
predicts=predict, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_length, # mutable
......@@ -271,6 +266,7 @@ class EagleVerifyInputV2Mixin:
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict,
topk=self.topk,
)
else:
# Apply temperature and get target probs
......@@ -338,7 +334,7 @@ class EagleVerifyInputV2Mixin:
return predict, accept_length, accept_index
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, disable=_is_npu)
def select_top_k_tokens_tmp(
i: int,
topk_p: torch.Tensor,
......@@ -456,3 +452,50 @@ def assign_extend_cache_locs(
tl.store(out_cache_ptr + save_offset, data, mask=mask)
load_offset += BLOCK_SIZE
save_offset += BLOCK_SIZE
def assign_extend_cache_locs_func(
req_pool_indices: torch.Tensor,
req_to_token: torch.Tensor,
start_offset: torch.Tensor,
end_offset: torch.Tensor,
batch_size: int,
draft_token_num: int,
device,
) -> torch.Tensor:
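"""
Gather each request's KV-cache locations for positions [start_offset, end_offset)
from req_to_token into a flat out_cache_loc tensor, using the Triton kernel on
CUDA/HIP and torch.ops.npu.cache_loc_update (sgl_kernel_npu) on NPU.
"""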
if _is_cuda or _is_hip:
out_cache_loc = torch.empty(
(batch_size * draft_token_num,),
dtype=torch.int64,
device=device,
)
assign_extend_cache_locs[(batch_size,)](
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
req_to_token.shape[1],
next_power_of_2(batch_size),
)
return out_cache_loc
elif _is_npu:
import sgl_kernel_npu # noqa: F401
out_cache_loc = torch.empty(
(batch_size * draft_token_num,),
dtype=torch.int32,
device=device,
)
torch.ops.npu.cache_loc_update(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
)
out_cache_loc = out_cache_loc.to(dtype=torch.int64)
return out_cache_loc
......@@ -4,14 +4,128 @@ from typing import List, Optional
import torch
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_npu
if is_cuda() or is_hip():
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
if _is_cuda or _is_hip:
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
def build_tree_efficient_native(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
topk: int,
draft_token_num: int,
tree_mask_mode: int,
bs: int,
):
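"""
Pure-PyTorch fallback of sgl_kernel's build_tree_kernel_efficient for devices
without the CUDA kernel (used on NPU). Fills tree_mask, retrive_index,
retrive_next_token and retrive_next_sibling and computes the position of every
draft token in its tree.
"""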
# Generate batch and token index ranges
bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
# Optimized common case for performance.
if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
positions = verified_seq_len.repeat_interleave(draft_token_num)
positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
retrive_next_token[:, 0] = 1
retrive_next_token[:, 1] = -1
return (
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
tree_mask,
)
# Precompute sequence tree indices
draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
cum_seq_len = cum_seq_len[:-1]
seq_tree_idx = (
draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
+ cum_seq_len
)
# Batch processing for tree mask
if tree_mask_mode == TreeMaskMode.FULL_MASK:
token_tree_base = (
seq_tree_idx.view(-1, 1)
+ (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
)
token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
else:
token_tree_indices = (
bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
)
tree_mask[token_tree_indices.flatten() - 1] = True
indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
tree_mask[indices.view(-1)] = False
positions = verified_seq_len.repeat_interleave(draft_token_num)
parent_tb_indices = selected_index // topk
retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
for bid in range(bs):
for tid in range(draft_token_num):
position = 0
if tid == 0:
# Process root node
for i in range(draft_token_num - 1, 0, -1):
parent_position = 0
parent_tb_idx = parent_tb_indices[bid][i - 1]
if parent_tb_idx > 0:
parent_token_idx = parent_list[bid][parent_tb_idx]
loop_num = draft_token_num - parent_position
for _ in range(loop_num):
if selected_index[bid][parent_position] == parent_token_idx:
parent_position += 1
break
parent_position += 1
if parent_position == draft_token_num:
continue
if retrive_next_token[bid][parent_position] != -1:
retrive_next_sibling[bid][i] = retrive_next_token[bid][
parent_position
]
retrive_next_token[bid][parent_position] = i
else:
# Process non-root nodes
cur_position = tid - 1
while True:
position += 1
if cur_position >= draft_token_num:
tree_mask[token_tree_indices + cur_position] = True
parent_tb_idx = selected_index[bid][cur_position] // topk
else:
parent_tb_idx = parent_tb_indices[bid][cur_position]
if parent_tb_idx == 0:
break
token_idx = parent_list[bid][parent_tb_idx]
cur_position = 0
for _ in range(draft_token_num):
if selected_index[bid][cur_position] == token_idx:
break
cur_position += 1
positions[bid * draft_token_num + tid] += position
return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
def organize_draft_results(
score_list: List[torch.Tensor],
token_list: List[torch.Tensor],
......@@ -114,20 +228,41 @@ def build_tree_kernel_efficient(
(bs * num_verify_tokens,), device=device, dtype=torch.long
)
sgl_build_tree_kernel_efficient(
parent_list,
top_scores_index,
seq_lens,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
spec_steps,
num_verify_tokens,
tree_mask_mode,
)
if _is_npu:
(
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
tree_mask,
) = build_tree_efficient_native(
parent_list,
top_scores_index,
seq_lens,
tree_mask,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
num_verify_tokens,
tree_mask_mode,
bs,
)
else:
sgl_build_tree_kernel_efficient(
parent_list,
top_scores_index,
seq_lens,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
spec_steps,
num_verify_tokens,
tree_mask_mode,
)
return (
tree_mask,
positions,
......@@ -136,3 +271,113 @@ def build_tree_kernel_efficient(
retrive_next_sibling,
draft_tokens,
)
def verify_tree_greedy_native(
predicts: torch.Tensor,
accept_index: torch.Tensor,
accept_token_num: torch.Tensor,
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
target_predict: torch.Tensor,
topk: int = -1,
):
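"""
Pure-PyTorch fallback of sgl_kernel's verify_tree_greedy. For each request it
walks the draft tree via retrive_next_token/retrive_next_sibling, accepting
draft tokens while they match the target model's greedy predictions, and
records the accepted indices and counts.
"""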
batch_size, num_draft_tokens = candidates.shape
# Optimized common case for performance.
if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
comparison_result = candidates[:, 1] == target_predict[:, 0]
predicts = target_predict.flatten()
accept_index = torch.arange(
0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
).reshape(batch_size, num_draft_tokens)
comparison_result = comparison_result.to(torch.int64)
accept_index_mask = accept_index[:, 1] * comparison_result
accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
accept_token_num = comparison_result.int()
return predicts, accept_index, accept_token_num
# BFS
for bx in range(batch_size):
cur_candidates = candidates[bx]
cur_retrive_index = retrive_index[bx]
cur_next_token = retrive_next_token[bx]
cur_next_sibling = retrive_next_sibling[bx]
cur_target = target_predict[bx]
last_accepted_idx = cur_retrive_index[0]
accept_index[bx, 0] = last_accepted_idx
num_accepted = 0
cur_node = 0
for _ in range(1, num_draft_tokens):
cur_node = cur_next_token[cur_node]
found = False
while cur_node != -1:
draft_idx = cur_retrive_index[cur_node]
draft_token = cur_candidates[cur_node]
target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
if draft_token == target_token:
predicts[last_accepted_idx] = target_token
num_accepted += 1
accept_index[bx, num_accepted] = draft_idx
last_accepted_idx = draft_idx
found = True
break
else:
cur_node = cur_next_sibling[cur_node]
if not found:
break
accept_token_num[bx] = num_accepted
predicts[last_accepted_idx] = cur_target[
last_accepted_idx - num_draft_tokens * bx
]
return predicts, accept_index, accept_token_num
def verify_tree_greedy_func(
predicts: torch.Tensor,
accept_index: torch.Tensor,
accept_token_num: torch.Tensor,
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
target_predict: torch.Tensor,
topk: int = -1,
):
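"""
Greedy tree verification dispatcher: use the sgl_kernel CUDA/HIP kernel when
available, otherwise the pure-PyTorch fallback (used on NPU). Returns
(predicts, accept_index, accept_token_num).
"""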
if _is_cuda or _is_hip:
from sgl_kernel import verify_tree_greedy
verify_tree_greedy(
predicts=predicts, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_token_num, # mutable
candidates=candidates,
retrive_index=retrive_index,
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
target_predict=target_predict,
)
elif _is_npu:
predicts, accept_index, accept_token_num = verify_tree_greedy_native(
predicts=predicts, # mutable
accept_index=accept_index, # mutable
accept_token_num=accept_token_num, # mutable
candidates=candidates,
retrive_index=retrive_index,
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
target_predict=target_predict,
topk=topk,
)
return predicts, accept_index, accept_token_num
......@@ -53,9 +53,12 @@ from sglang.srt.utils import (
get_available_gpu_memory,
get_bool_env_var,
is_cuda,
is_npu,
next_power_of_2,
)
_is_npu = is_npu()
if is_cuda():
from sgl_kernel import segment_packbits # noqa: F401
......@@ -205,7 +208,7 @@ class EAGLEWorker(TpModelWorker):
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph:
if self.server_args.disable_cuda_graph or _is_npu:
return
# Capture draft
......@@ -945,7 +948,7 @@ class EAGLEWorker(TpModelWorker):
draft_input.hidden_states = logits_output.hidden_states
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, disable=_is_npu)
def get_last_loc_large_page_size_top_k_1(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
......
......@@ -4,7 +4,6 @@ import time
from typing import Any, List, Optional, Tuple
import torch
from torch.cuda import Stream as CudaStream
from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
......@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
empty_context,
fast_topk,
get_available_gpu_memory,
is_npu,
next_power_of_2,
)
_is_npu = is_npu()
logger = logging.getLogger(__name__)
def _get_plan_stream(
device: str,
) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
) -> Tuple[Any, contextlib.AbstractContextManager]:
if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
plan_stream: CudaStream = torch.get_device_module(device).Stream()
plan_stream_ctx = torch.cuda.stream(plan_stream)
plan_stream = torch.get_device_module(device).Stream()
plan_stream_ctx = torch.get_device_module(device).stream(plan_stream)
return plan_stream, plan_stream_ctx
else:
return None, contextlib.nullcontext()
......@@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker):
self.cuda_graph_runner = None
self.cuda_graph_runner_for_draft_extend = None
if self.server_args.disable_cuda_graph:
if self.server_args.disable_cuda_graph or _is_npu:
return
# Capture draft
......@@ -456,7 +458,9 @@ class EagleDraftWorker(BaseDraftWorker):
)
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
torch.get_device_module(self.device).current_stream().wait_stream(
self.plan_stream
)
# Run draft extend batch in the main compute stream
draft_logits_output = self.draft_runner.model.forward(
......@@ -577,7 +581,9 @@ class EAGLEWorkerV2(BaseSpecWorker):
# Since batch.seq_lens is allocated in another stream, we need
# record_stream() to prevent pytorch gc and reuse the gpu memory
# while forward_stream is still running.
batch.seq_lens.record_stream(torch.cuda.current_stream())
batch.seq_lens.record_stream(
torch.get_device_module(self.device).current_stream()
)
# Parse args
verify_input: EagleVerifyInput = batch.spec_info
......@@ -596,7 +602,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
# Correct some buffers due to the overlap plan
if self.plan_stream:
torch.cuda.current_stream().wait_stream(self.plan_stream)
torch.get_device_module().current_stream().wait_stream(self.plan_stream)
# Some values such as custom_mask and position depend on the output of draft,
# so the previous plan step used the wrong values. Here, we need to run the related
......@@ -628,7 +634,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
accept_index,
) = verify_input.sample(batch, logits_output)
new_seq_lens = batch.seq_lens + accept_length
verify_done = torch.cuda.Event()
verify_done = torch.get_device_module(self.device).Event()
verify_done.record()
all_verified_id = predict[accept_index]
......
......@@ -19,16 +19,22 @@ from sglang.srt.distributed.parallel_state import (
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if is_cuda():
if _is_cuda:
from sgl_kernel import fast_topk
elif is_hip():
elif _is_hip:
from sgl_kernel import fast_topk
else:
from sglang.srt.utils.common import fast_topk
logger = logging.getLogger(__name__)
......@@ -39,7 +45,7 @@ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = is_cuda() # This kernel is only available for CUDA now
TREE_SPEC_KERNEL_AVAILABLE = _is_cuda # This kernel is only available for CUDA now
@triton.jit
......@@ -103,6 +109,36 @@ def assign_req_to_token_pool(
load_offset += BLOCK_SIZE
def assign_req_to_token_pool_func(
req_pool_indices: torch.Tensor,
req_to_token: torch.Tensor,
start_offset: torch.Tensor,
end_offset: torch.Tensor,
out_cache_loc: torch.Tensor,
batch_size: int,
):
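"""
Write out_cache_loc into req_to_token[req_pool_indices, start_offset:end_offset],
dispatching to the Triton kernel on CUDA/HIP and torch.ops.npu.cache_loc_assign
(sgl_kernel_npu) on NPU.
"""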
if _is_cuda or _is_hip:
assign_req_to_token_pool[(batch_size,)](
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
req_to_token.shape[1],
next_power_of_2(batch_size),
)
elif _is_npu:
import sgl_kernel_npu # noqa: F401
torch.ops.npu.cache_loc_assign(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
)
@triton.jit
def assign_draft_cache_locs(
req_pool_indices,
......@@ -331,7 +367,7 @@ def get_target_cache_loc(
)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, disable=_is_npu)
def get_src_tgt_cache_loc(
seq_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
......@@ -381,7 +417,7 @@ def filter_finished_cache_loc_kernel(
)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, disable=_is_npu)
def create_accept_length_filter(
accept_length: torch.Tensor,
unfinished_index_device: torch.Tensor,
......@@ -395,7 +431,7 @@ def create_accept_length_filter(
return accept_length_filter
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, disable=_is_npu)
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
......@@ -413,7 +449,7 @@ def select_top_k_tokens(
tree_info = (
topk_p.unsqueeze(1), # shape: (b, 1, topk)
topk_index, # shape: (b, topk)
torch.arange(-1, topk, dtype=torch.long, device="cuda")
torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
......
......@@ -3106,12 +3106,16 @@ def apply_module_patch(target_module, target_function, wrappers):
setattr(original_module, target_function, candidate)
for key, value in sys.modules.copy().items():
if (
target_function is not None
and hasattr(value, target_function)
and id(getattr(value, target_function)) == original_function_id
):
setattr(value, target_function, candidate)
try:
if (
target_function is not None
and hasattr(value, target_function)
and id(getattr(value, target_function)) == original_function_id
):
setattr(value, target_function, candidate)
except ImportError as e:
# Some modules raise ImportError when hasattr probes them; skip those modules.
logger.warning(f"Ignoring {value}: hasattr raised ImportError:\n{e}")
def parse_module_path(module_path, function_name, create_dummy):
......
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
TEST_MODEL_MATRIX = {
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
"accuracy": 0.95,
"latency": 1000,
"output_throughput": 6,
},
}
class TestAscendDeepSeekMTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.models = TEST_MODEL_MATRIX.keys()
cls.base_url = DEFAULT_URL_FOR_TEST
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
cls.common_args = [
"--trust-remote-code",
"--attention-backend",
"ascend",
"--quantization",
"w8a8_int8",
"--mem-fraction-static",
0.8,
"--disable-radix-cache",
"--chunked-prefill-size",
32768,
"--tp-size",
16,
"--speculative-algorithm",
"NEXTN",
"--speculative-num-steps",
1,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
2,
]
cls.extra_envs = {
"SGLANG_NPU_USE_MLAPO": "1",
}
os.environ.update(cls.extra_envs)
def test_a_gsm8k(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing accuracy: {model} ===##")
process = popen_launch_server(
model,
self.base_url,
timeout=1500,
other_args=[
*self.common_args,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=1319,
max_new_tokens=512,
parallel=128,
host=f"http://{self.url.hostname}",
port=int(self.url.port),
)
metrics = run_eval_few_shot_gsm8k(args)
self.assertGreaterEqual(
metrics["accuracy"],
TEST_MODEL_MATRIX[model]["accuracy"],
)
finally:
kill_process_tree(process.pid)
def test_b_throughput(self):
for model in self.models:
with self.subTest(model=model):
print(f"##=== Testing throughput: {model} ===##")
output_throughput = run_bench_offline_throughput(
model,
[
*self.common_args,
],
)
print(f"##=== {model} throughput: {output_throughput} ===##")
if is_in_ci():
self.assertGreater(
output_throughput,
TEST_MODEL_MATRIX[model]["output_throughput"],
)
if __name__ == "__main__":
unittest.main()
......@@ -359,6 +359,7 @@ suite_ascend = {
],
"per-commit-16-ascend-a3": [
TestFile("ascend/test_ascend_deepep.py", 400),
TestFile("ascend/test_ascend_deepseek_mtp.py", 400),
],
}
......