Unverified commit ce6b17c0, authored by Even Zhou, committed by GitHub

[Feature] Support DeepSeek MTP on NPU (#11897)


Co-authored-by: liupeng374 <liupeng374@huawei.com>
parent cafebef1
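For orientation before the diff: DeepSeek MTP (multi-token prediction) speculation lets a small draft head propose several tokens that the target model then verifies in a single forward pass. A minimal, self-contained sketch of that loop, with toy placeholder functions rather than the sglang API:

```python
import torch

def mtp_step(tokens, draft_forward, target_forward, k: int = 1):
    """One speculative step: draft k tokens, verify them with the target model.

    draft_forward(tokens, k) -> k proposed next tokens (stand-in for the MTP head).
    target_forward(tokens)   -> greedy next-token prediction at every position.
    """
    draft = draft_forward(tokens, k)               # k proposed tokens
    candidate = torch.cat([tokens, draft])         # score them all in one pass
    target = target_forward(candidate)[-(k + 1):]  # target's choice at each new slot
    n_accept = 0
    while n_accept < k and draft[n_accept] == target[n_accept]:
        n_accept += 1                              # longest agreeing prefix
    # Even with zero accepts we still gain the target's own next token.
    return torch.cat([tokens, draft[:n_accept], target[n_accept:n_accept + 1]])
```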
@@ -65,7 +65,7 @@ jobs:
    if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
    runs-on: linux-arm64-npu-2
    strategy:
-     fail-fast: false
+     fail-fast: true
      matrix:
        part: [0, 1, 2]
    container:
@@ -144,6 +144,10 @@ jobs:
  per-commit-16-ascend-a3:
    if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
    runs-on: linux-aarch64-a3-16
+   strategy:
+     fail-fast: true
+     matrix:
+       part: [0, 1]
    container:
      image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-a3-ubuntu22.04-py3.11
    steps:
@@ -177,4 +181,4 @@ jobs:
      run: |
        export PATH="/usr/local/Ascend/8.3.RC1/compiler/bishengir/bin:${PATH}"
        cd test/srt
-       python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600
+       python3 run_suite.py --suite per-commit-16-ascend-a3 --timeout-per-file 3600 --auto-partition-id ${{ matrix.part }} --auto-partition-size 2
@@ -59,6 +59,19 @@ class AscendAttnBackend(AttentionBackend):
        )
        self.mask_len = max_seq_len
+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers for verify attention kernels that need to be filled after draft.
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        pass
    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata = None
@@ -87,15 +100,22 @@ class AscendAttnBackend(AttentionBackend):
                device=model_runner.device,
            )
        )
+       self.speculative_num_draft_tokens = (
+           model_runner.server_args.speculative_num_draft_tokens
+       )
+       self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
+       self.mtp_mask = ~self.mtp_mask
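The inverted lower-triangular mask matches the convention of the fused NPU attention kernel used below, where True marks positions to exclude (an assumption consistent with how it is passed as atten_mask with sparse_mode=3). A quick illustration at size 4:

```python
import torch

m = torch.tril(torch.ones(4, 4, dtype=torch.bool))
print(~m)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# Row i keeps columns <= i (False = attend); future positions are True (masked).
```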
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        tp_size = get_attention_tp_size()
        self.forward_metadata = ForwardMetadata()
+       seq_lens_max = forward_batch.seq_lens.max()
+       if forward_batch.forward_mode.is_target_verify():
+           seq_lens_max += self.speculative_num_draft_tokens
        self.forward_metadata.block_tables = (
            forward_batch.req_to_token_pool.req_to_token[
-               forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
+               forward_batch.req_pool_indices, :seq_lens_max
            ][:, :: self.page_size]
            // self.page_size
        )
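A worked example of the block-table construction above (hypothetical numbers, page_size = 128): slot ids are sampled every page_size-th position and divided down to page ids.

```python
import torch

page_size = 128
# One request whose 300 tokens occupy contiguous KV slots 512..811.
req_to_token = torch.arange(512, 512 + 300).unsqueeze(0)
seq_lens_max = 300
block_tables = req_to_token[:, :seq_lens_max][:, ::page_size] // page_size
print(block_tables)  # tensor([[4, 5, 6]]): the request spans pages 4, 5 and 6
```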
@@ -104,16 +124,23 @@ class AscendAttnBackend(AttentionBackend):
                forward_batch.extend_seq_lens.cpu().int()
            )
            self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+           if (
+               not forward_batch.forward_mode.is_draft_extend_v2()
+               and not forward_batch.forward_mode.is_draft_extend()
+               and not forward_batch.forward_mode.is_target_verify()
+           ):
                seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
                self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

+       if forward_batch.forward_mode.is_target_verify():
+           self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens

        self.graph_mode = False
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.graph_metadata = {
            "block_tables": torch.empty(
-               (max_bs, self.max_context_len // self.page_size),
+               (max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),
                dtype=torch.int32,
                device=self.device,
            ),
@@ -156,6 +183,8 @@ class AscendAttnBackend(AttentionBackend):
    ):
        metadata = self.graph_metadata[bs]
        max_len = seq_lens_cpu[:bs].max().item()
+       if forward_mode.is_target_verify():
+           max_len += self.speculative_num_draft_tokens

        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
        metadata.block_tables[:bs, :max_seq_pages].copy_(
@@ -257,6 +286,25 @@ class AscendAttnBackend(AttentionBackend):
                k_rope,
                topk_indices,
            )

+       if (
+           forward_batch.forward_mode.is_target_verify()
+           or forward_batch.forward_mode.is_draft_extend()
+           or forward_batch.forward_mode.is_draft_extend_v2()
+       ):
+           if is_mla_preprocess_enabled():
+               save_kv_cache = False
+           return self.forward_mtp(
+               q,
+               k,
+               v,
+               layer,
+               forward_batch,
+               save_kv_cache,
+               q_rope=q_rope,
+               k_rope=k_rope,
+           )

        if not self.use_mla:
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -393,6 +441,118 @@ class AscendAttnBackend(AttentionBackend):
        )
        return attn_output
+    def forward_mtp(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+
+        c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        k_rope_cache = k_rope.view(
+            -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+        )
+        c_kv_cache = c_kv.view(
+            -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+        )
+        q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
+        q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
+
+        if not self.graph_mode:
+            num_token_padding = q.shape[0]
+            q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]
+            q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]
+
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list
+        else:
+            actual_seq_lengths_kv = (
+                self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+            )
+
+        if forward_batch.forward_mode.is_draft_extend():
+            actual_seq_lengths = (
+                np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()
+            )
+        else:
+            actual_seq_lengths = np.arange(
+                self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens + q_nope.shape[0],
+                self.speculative_num_draft_tokens,
+            )
+
+        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+        )
+        attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
+        softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+        torch_npu.npu_fused_infer_attention_score.out(
+            q_nope,
+            c_kv_cache,
+            c_kv_cache,
+            query_rope=q_rope,
+            key_rope=k_rope_cache,
+            num_heads=layer.tp_q_head_num,
+            num_key_value_heads=layer.tp_k_head_num,
+            input_layout="TND",
+            scale=layer.scaling,
+            antiquant_mode=0,
+            antiquant_scale=None,
+            block_table=self.forward_metadata.block_tables,
+            block_size=self.page_size,
+            sparse_mode=3,
+            atten_mask=self.mtp_mask,
+            actual_seq_lengths=actual_seq_lengths,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            workspace=workspace,
+            out=[attn_output, softmax_lse],
+        )
+        attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        if (
+            not self.graph_mode
+            and forward_batch.num_token_non_padded_cpu != num_token_padding
+        ):
+            attn_output = torch.cat(
+                [
+                    attn_output,
+                    attn_output.new_zeros(
+                        num_token_padding - attn_output.shape[0], *attn_output.shape[1:]
+                    ),
+                ],
+                dim=0,
+            )
+        return attn_output
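In the TND layout both actual_seq_lengths fields are cumulative token counts. A worked example for the target-verify branch (hypothetical batch of 3 requests, speculative_num_draft_tokens = 2):

```python
import numpy as np

speculative_num_draft_tokens = 2
num_q_tokens = 6  # 3 requests x 2 draft tokens each, i.e. q_nope.shape[0]
actual_seq_lengths = np.arange(
    speculative_num_draft_tokens,
    speculative_num_draft_tokens + num_q_tokens,
    speculative_num_draft_tokens,
)
print(actual_seq_lengths.tolist())  # [2, 4, 6]: query rows for request i end at 2*(i+1)
```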
    def forward_decode_graph(
        self,
        q: torch.Tensor,

@@ -690,3 +850,71 @@ class AscendAttnBackend(AttentionBackend):
            out=attn_output,
        )
        return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
+class AscendAttnMultiStepDraftBackend:
+    """
+    Wrap multiple Ascend attention backends as one for multiple consecutive
+    draft decoding steps
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.attn_backends = []
+        for _ in range(self.speculative_num_steps):
+            self.attn_backends.append(AscendAttnBackend(model_runner))
+
+    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+        assert forward_batch.spec_info is not None
+        for i in range(self.speculative_num_steps - 1):
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_cuda_graph_state(self, max_bs, max_num_tokens):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=None,
+            )
+
+        self.common_template(forward_batch, call_fn)
@@ -77,6 +77,9 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list
+from sglang.srt.utils.common import is_npu
+
+_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig

@@ -1050,7 +1053,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    has_grammar: bool = False

    # Device
-   device: str = "cuda"
+   if not _is_npu:
+       device: str = "cuda"
+   else:
+       device: str = "npu"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
......
@@ -75,6 +75,10 @@ class NPUGraphRunner(CudaGraphRunner):
        # Replay
        if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
+           if forward_batch.forward_mode.is_target_verify():
+               seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+               seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
+           else:
                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
                    self.bs - self.raw_bs
                )
......
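Replay pads per-request lengths to the captured graph batch size, and under target-verify each length grows by the draft-token count. With hypothetical numbers:

```python
import torch

num_tokens_per_bs = 2     # draft tokens verified per request
graph_bs, raw_bs = 4, 2   # captured graph batch vs. live batch
seq_lens_live = torch.tensor([17, 33])
seq_lens = (seq_lens_live + num_tokens_per_bs).tolist() + [0] * (graph_bs - raw_bs)
print(seq_lens)  # [19, 35, 0, 0]: padded slots carry zero-length sequences
```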
@@ -38,12 +38,13 @@ from sglang.srt.models.deepseek_v2 import (
    enable_nextn_moe_bf16_cast_to_fp8,
)
from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_npu

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
+_is_npu = is_npu()

class DeepseekModelNextN(nn.Module):

@@ -85,13 +86,21 @@ class DeepseekModelNextN(nn.Module):
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

        self.alt_stream = torch.cuda.Stream() if _is_cuda else None

+       layer_name = "decoder"
+       if _is_npu and (
+           get_global_server_args().speculative_draft_model_path
+           == get_global_server_args().model_path
+       ):
+           layer_name = "layers." + str(config.num_hidden_layers)
+
        self.decoder = DeepseekV2DecoderLayer(
            config,
            0,
            quant_config=quant_config,
            moe_quant_config=moe_quant_config,
            is_nextn=True,
-           prefix=add_prefix("decoder", prefix),
+           prefix=add_prefix(layer_name, prefix),
            alt_stream=self.alt_stream,
        )
......
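The remap matters when the draft model path equals the target path: DeepSeek checkpoints store the MTP block as one extra decoder layer indexed num_hidden_layers, so the NextN weights must be looked up under that name. A small illustration, with add_prefix shown as a stand-in whose '.'-joining behavior is assumed:

```python
def add_prefix(name: str, prefix: str) -> str:
    # Stand-in for sglang.srt.utils.add_prefix (assumed joining behavior).
    return name if not prefix else f"{prefix}.{name}"

num_hidden_layers = 61  # e.g. DeepSeek-V3/R1
layer_name = "layers." + str(num_hidden_layers)
print(add_prefix(layer_name, "model"))  # model.layers.61 -> MTP weights inside the base checkpoint
print(add_prefix("decoder", "model"))   # model.decoder   -> standalone NextN draft checkpoints
```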
@@ -290,6 +290,7 @@ def handle_attention_ascend(attn, forward_batch):
        forward_batch.forward_mode.is_extend()
        and not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
+       and not forward_batch.forward_mode.is_draft_extend_v2()
    ):
        if hasattr(attn, "indexer"):
            return AttnForwardMethod.NPU_MLA_SPARSE

@@ -3753,8 +3754,12 @@ class DeepseekV2ForCausalLM(nn.Module):
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
+       if not _is_npu:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
+       else:
+           torch.npu.empty_cache()
+           torch.npu.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
......
@@ -49,6 +49,7 @@ class DraftBackendFactory:
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
            "nsa": self._create_nsa_decode_backend,
+           "ascend": self._create_ascend_decode_backend,
        }

        return self._create_backend(

@@ -72,6 +73,7 @@ class DraftBackendFactory:
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
            "nsa": self._create_nsa_prefill_backend,
+           "ascend": self._create_ascend_prefill_backend,
        }
        backend_name = (
            "decode_attention_backend"

@@ -173,6 +175,15 @@ class DraftBackendFactory:
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

+   def _create_ascend_decode_backend(self):
+       from sglang.srt.layers.attention.ascend_backend import (
+           AscendAttnMultiStepDraftBackend,
+       )
+
+       return AscendAttnMultiStepDraftBackend(
+           self.draft_model_runner, self.topk, self.speculative_num_steps
+       )
+
    def _create_flashinfer_prefill_backend(self):
        if not get_global_server_args().use_mla_backend:
            from sglang.srt.layers.attention.flashinfer_backend import (

@@ -219,6 +230,11 @@ class DraftBackendFactory:
        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

+   def _create_ascend_prefill_backend(self):
+       from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+       return AscendAttnBackend(self.draft_model_runner)
+
    def _create_flashmla_prefill_backend(self):
        logger.warning(
            "flashmla prefill backend is not yet supported for draft extend."
......
@@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import (
    EagleDraftInputV2Mixin,
    EagleVerifyInputV2Mixin,
)
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
    SIMULATE_ACC_LEN,
    TREE_SPEC_KERNEL_AVAILABLE,
    align_evict_mask_to_page_size,
-   assign_req_to_token_pool,
+   assign_req_to_token_pool_func,
    create_accept_length_filter,
    create_extend_after_decode_spec_info,
    filter_finished_cache_loc_kernel,

@@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import (
    get_src_tgt_cache_loc,
    get_target_cache_loc,
)
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
+
+_is_npu = is_npu()

if is_cuda():
    from sgl_kernel import (
        top_k_renorm_prob,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
-       verify_tree_greedy,
    )
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy

logger = logging.getLogger(__name__)
@@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
    @classmethod
    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+       if not _is_npu:
+           device = "cuda"
+       else:
+           device = "npu"
        return cls(
-           draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+           draft_token=torch.empty((0,), dtype=torch.long, device=device),
-           custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+           custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
-           positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+           positions=torch.empty((0,), dtype=torch.int64, device=device),
            retrive_index=torch.full(
-               (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+               (0, num_verify_tokens), -1, dtype=torch.long, device=device
            ),
            retrive_next_token=torch.full(
-               (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+               (0, num_verify_tokens), -1, dtype=torch.long, device=device
            ),
            retrive_next_sibling=torch.full(
-               (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+               (0, num_verify_tokens), -1, dtype=torch.long, device=device
            ),
            retrive_cum_len=None,
            topk=topk,
@@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
        self.last_loc = last_loc

        bs = batch.batch_size()
-       assign_req_to_token_pool[(bs,)](
+       assign_req_to_token_pool_func(
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            end_offset,
            batch.out_cache_loc,
-           batch.req_to_token_pool.req_to_token.shape[1],
-           next_power_of_2(bs),
+           bs,
        )
    def generate_attn_arg_prefill(

@@ -151,16 +154,17 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
+       device = req_pool_indices.device
        batch_size = len(req_pool_indices)
        qo_indptr = torch.arange(
            0,
            (1 + batch_size) * self.draft_token_num,
            step=self.draft_token_num,
            dtype=torch.int32,
-           device="cuda",
+           device=device,
        )
        cum_kv_seq_len = torch.zeros(
-           (batch_size + 1,), dtype=torch.int32, device="cuda"
+           (batch_size + 1,), dtype=torch.int32, device=device
        )

        paged_kernel_lens = paged_kernel_lens + self.draft_token_num

@@ -169,7 +173,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
        kv_indices = torch.empty(
            paged_kernel_lens_sum + self.draft_token_num * batch_size,
            dtype=torch.int32,
-           device="cuda",
+           device=device,
        )
        create_flashinfer_kv_indices_triton[(batch_size,)](
            req_to_token,
@@ -226,11 +230,11 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
-       predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda")
+       predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
        accept_index = torch.full(
-           (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+           (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
        )
-       accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
+       accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)

        if bs != len(sampling_info):
            sampling_info = copy.deepcopy(sampling_info)

@@ -254,7 +258,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
            linear_penalty = torch.zeros(
                (bs, logits_output.next_token_logits.shape[1]),
                dtype=torch.float32,
-               device="cuda",
+               device=batch.device,
            )
            sampling_info.apply_logits_bias(linear_penalty)
            logits_output.next_token_logits.add_(

@@ -276,11 +280,10 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
                "Falling back to greedy verification."
            )

-       if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
+       if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)
-           verify_tree_greedy(
+           predict, accept_index, accept_length = verify_tree_greedy_func(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable

@@ -289,7 +292,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
+               topk=self.topk,
            )
        else:
            # apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(

@@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            draft_probs = torch.zeros(
-               target_probs.shape, dtype=torch.float32, device="cuda"
+               target_probs.shape, dtype=torch.float32, device=batch.device
            )

            # coins for rejection sampling
-           coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda")
+           coins = torch.rand_like(
+               candidates, dtype=torch.float32, device=batch.device
+           )
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
-               (bs,), dtype=torch.float32, device="cuda"
+               (bs,), dtype=torch.float32, device=batch.device
            )
            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
@@ -468,14 +475,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
        if not has_finished:
            if page_size == 1 or self.topk == 1:
                batch.out_cache_loc = batch.out_cache_loc[accept_index]
-               assign_req_to_token_pool[(bs,)](
+               assign_req_to_token_pool_func(
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc,
-                   batch.req_to_token_pool.req_to_token.shape[1],
-                   next_power_of_2(bs),
+                   bs,
                )
            else:
                batch.out_cache_loc = tgt_cache_loc
@@ -501,14 +507,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
            )
        else:
            if page_size == 1 or self.topk == 1:
-               assign_req_to_token_pool[(bs,)](
+               assign_req_to_token_pool_func(
                    batch.req_pool_indices,
                    batch.req_to_token_pool.req_to_token,
                    batch.seq_lens,
                    batch.seq_lens + accept_length + 1,
                    batch.out_cache_loc[accept_index],
-                   batch.req_to_token_pool.req_to_token.shape[1],
-                   next_power_of_2(bs),
+                   bs,
                )
                batch.seq_lens.add_(accept_length + 1)
                batch.seq_lens_cpu.add_(accept_length_cpu + 1)
@@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
+       device = req_pool_indices.device
        bs = self.accept_length.numel()
-       qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+       qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
-       cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+       cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        if paged_kernel_lens_sum is None:
            paged_kernel_lens_sum = cum_kv_seq_len[-1]

        kv_indices = torch.empty(
-           paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+           paged_kernel_lens_sum, dtype=torch.int32, device=device
        )
        create_flashinfer_kv_indices_triton[(bs,)](
......
@@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
+from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
from sglang.srt.speculative.spec_utils import (
    SIMULATE_ACC_LEN,
    generate_simulated_accept_index,
)
-from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker

@@ -41,11 +46,8 @@ if is_cuda():
        top_k_renorm_prob,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
-       verify_tree_greedy,
    )
    from sgl_kernel.top_k import fast_topk
-elif is_hip():
-    from sgl_kernel import verify_tree_greedy
@triton.jit

@@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1(

@dataclass
class EagleDraftInputV2Mixin:
    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
-       from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+       from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
        bs = batch.batch_size()

@@ -112,15 +114,15 @@ class EagleDraftInputV2Mixin:
            extend_num_tokens,
        )
-       assign_req_to_token_pool[(bs,)](
+       assign_req_to_token_pool_func(
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            self.allocate_lens,
            new_allocate_lens,
            out_cache_loc,
-           batch.req_to_token_pool.req_to_token.shape[1],
-           next_power_of_2(bs),
+           bs,
        )
        self.allocate_lens = new_allocate_lens

        # FIXME(lsyin): make this sync optional
@@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin:
        bs = len(batch.req_pool_indices)
        batch.input_ids = self.draft_token
        device = batch.input_ids.device
-       batch.out_cache_loc = torch.empty(
-           (bs * self.draft_token_num,),
-           dtype=torch.int64,
+       batch.out_cache_loc = assign_extend_cache_locs_func(
+           req_pool_indices=batch.req_pool_indices,
+           req_to_token=req_to_token_pool.req_to_token,
+           start_offset=batch.seq_lens,
+           end_offset=batch.seq_lens + self.draft_token_num,
+           batch_size=bs,
+           draft_token_num=self.draft_token_num,
            device=device,
        )
-       assign_extend_cache_locs[(bs,)](
-           batch.req_pool_indices,
-           req_to_token_pool.req_to_token,
-           batch.seq_lens,
-           batch.seq_lens + self.draft_token_num,
-           batch.out_cache_loc,
-           req_to_token_pool.req_to_token.shape[1],
-           next_power_of_2(bs),
-       )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
@@ -258,11 +254,10 @@ class EagleVerifyInputV2Mixin:
        accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

        # Sample tokens
-       if sampling_info.is_all_greedy:
+       if sampling_info.is_all_greedy or _is_npu:
            target_predict = torch.argmax(next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)
-           verify_tree_greedy(
+           predict, accept_index, accept_length = verify_tree_greedy_func(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable

@@ -271,6 +266,7 @@ class EagleVerifyInputV2Mixin:
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
+               topk=self.topk,
            )
        else:
            # Apply temperature and get target probs

@@ -338,7 +334,7 @@ class EagleVerifyInputV2Mixin:
        return predict, accept_length, accept_index
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
def select_top_k_tokens_tmp(
    i: int,
    topk_p: torch.Tensor,

@@ -456,3 +452,50 @@ def assign_extend_cache_locs(
        tl.store(out_cache_ptr + save_offset, data, mask=mask)
        load_offset += BLOCK_SIZE
        save_offset += BLOCK_SIZE
+def assign_extend_cache_locs_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    batch_size: int,
+    draft_token_num: int,
+    device,
+) -> torch.Tensor:
+    if _is_cuda or _is_hip:
+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int64,
+            device=device,
+        )
+        assign_extend_cache_locs[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+        return out_cache_loc
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        out_cache_loc = torch.empty(
+            (batch_size * draft_token_num,),
+            dtype=torch.int32,
+            device=device,
+        )
+        torch.ops.npu.cache_loc_update(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+        out_cache_loc = out_cache_loc.to(dtype=torch.int64)
+        return out_cache_loc
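Both branches compute the same thing. A pure-Python reference of the intended semantics (illustrative only, not a drop-in replacement): for each request, gather req_to_token[req, start:end] into one flat out_cache_loc.

```python
import torch

def assign_extend_cache_locs_ref(req_pool_indices, req_to_token, start, end):
    # Gather the newly extended KV slots of each request into one flat tensor.
    rows = [
        req_to_token[r, s:e]
        for r, s, e in zip(req_pool_indices.tolist(), start.tolist(), end.tolist())
    ]
    return torch.cat(rows).to(torch.int64)

req_to_token = torch.arange(32).reshape(2, 16)  # toy slot table
out = assign_extend_cache_locs_ref(
    torch.tensor([0, 1]), req_to_token,
    start=torch.tensor([3, 5]), end=torch.tensor([5, 7]),
)
print(out.tolist())  # [3, 4, 21, 22]
```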
@@ -4,14 +4,128 @@ from typing import List, Optional
import torch

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu

-if is_cuda() or is_hip():
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
+
+if _is_cuda or _is_hip:
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )
+def build_tree_efficient_native(
+    parent_list: torch.Tensor,
+    selected_index: torch.Tensor,
+    verified_seq_len: torch.Tensor,
+    tree_mask: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    topk: int,
+    draft_token_num: int,
+    tree_mask_mode: int,
+    bs: int,
+):
+    # Generate batch and token index ranges
+    bs_range = torch.arange(bs, device=tree_mask.device).view(-1, 1)
+    draft_token_num_range = torch.arange(draft_token_num, device=tree_mask.device)
+
+    # Optimized common case for performance.
+    if draft_token_num == 2 and topk == 1 and tree_mask_mode == TreeMaskMode.FULL_MASK:
+        positions = verified_seq_len.repeat_interleave(draft_token_num)
+        positions = (positions.view(bs, -1) + draft_token_num_range).view(-1)
+
+        retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+        retrive_next_token[:, 0] = 1
+        retrive_next_token[:, 1] = -1
+        return (
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            tree_mask,
+        )
+
+    # Precompute sequence tree indices
+    draft_token_num_range1 = torch.arange(draft_token_num - 1, device=tree_mask.device)
+    cum_seq_len = torch.cumsum(verified_seq_len * draft_token_num, dim=0)
+    cum_seq_len = torch.cat((torch.tensor([0], device=tree_mask.device), cum_seq_len))
+    cum_seq_len = cum_seq_len[:-1]
+    seq_tree_idx = (
+        draft_token_num * draft_token_num * torch.arange(bs, device=tree_mask.device)
+        + cum_seq_len
+    )
+
+    # Batch processing for tree mask
+    if tree_mask_mode == TreeMaskMode.FULL_MASK:
+        token_tree_base = (
+            seq_tree_idx.view(-1, 1)
+            + (verified_seq_len.view(-1, 1) + draft_token_num) * draft_token_num_range
+        )
+        token_tree_indices = token_tree_base + verified_seq_len.view(-1, 1) + 1
+    else:
+        token_tree_indices = (
+            bs_range * draft_token_num**2 + draft_token_num_range * draft_token_num + 1
+        )
+
+    tree_mask[token_tree_indices.flatten() - 1] = True
+    indices = token_tree_indices.unsqueeze(-1) + draft_token_num_range1.view(1, 1, -1)
+    tree_mask[indices.view(-1)] = False
+
+    positions = verified_seq_len.repeat_interleave(draft_token_num)
+
+    parent_tb_indices = selected_index // topk
+    retrive_index[:] = bs_range * draft_token_num + draft_token_num_range
+    tree_mask[token_tree_indices.view(-1, 1) + draft_token_num_range1] = True
+
+    for bid in range(bs):
+        for tid in range(draft_token_num):
+            position = 0
+            if tid == 0:
+                # Process the root node
+                for i in range(draft_token_num - 1, 0, -1):
+                    parent_position = 0
+                    parent_tb_idx = parent_tb_indices[bid][i - 1]
+                    if parent_tb_idx > 0:
+                        parent_token_idx = parent_list[bid][parent_tb_idx]
+                        loop_num = draft_token_num - parent_position
+                        for _ in range(loop_num):
+                            if selected_index[bid][parent_position] == parent_token_idx:
+                                parent_position += 1
+                                break
+                            parent_position += 1
+                    if parent_position == draft_token_num:
+                        continue
+                    if retrive_next_token[bid][parent_position] != -1:
+                        retrive_next_sibling[bid][i] = retrive_next_token[bid][
+                            parent_position
+                        ]
+                    retrive_next_token[bid][parent_position] = i
+            else:
+                # Process non-root nodes
+                cur_position = tid - 1
+                while True:
+                    position += 1
+                    if cur_position >= draft_token_num:
+                        tree_mask[token_tree_indices + cur_position] = True
+                        parent_tb_idx = selected_index[bid][cur_position] // topk
+                    else:
+                        parent_tb_idx = parent_tb_indices[bid][cur_position]
+                    if parent_tb_idx == 0:
+                        break
+                    token_idx = parent_list[bid][parent_tb_idx]
+                    cur_position = 0
+                    for _ in range(draft_token_num):
+                        if selected_index[bid][cur_position] == token_idx:
+                            break
+                        cur_position += 1
+            positions[bid * draft_token_num + tid] += position
+    return positions, retrive_index, retrive_next_token, retrive_next_sibling, tree_mask
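The draft_token_num == 2, topk == 1 fast path above reduces to simple arithmetic. With two requests whose verified lengths are 5 and 9 (toy values):

```python
import torch

draft_token_num, bs = 2, 2
verified_seq_len = torch.tensor([5, 9])
rng = torch.arange(draft_token_num)
positions = verified_seq_len.repeat_interleave(draft_token_num)
positions = (positions.view(bs, -1) + rng).view(-1)
print(positions.tolist())  # [5, 6, 9, 10]: root token, then its single child

retrive_index = torch.arange(bs).view(-1, 1) * draft_token_num + rng
print(retrive_index.tolist())  # [[0, 1], [2, 3]]: flat token ids per request
```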
def organize_draft_results(
    score_list: List[torch.Tensor],
    token_list: List[torch.Tensor],

@@ -114,6 +228,27 @@ def build_tree_kernel_efficient(
        (bs * num_verify_tokens,), device=device, dtype=torch.long
    )

+   if _is_npu:
+       (
+           positions,
+           retrive_index,
+           retrive_next_token,
+           retrive_next_sibling,
+           tree_mask,
+       ) = build_tree_efficient_native(
+           parent_list,
+           top_scores_index,
+           seq_lens,
+           tree_mask,
+           retrive_index,
+           retrive_next_token,
+           retrive_next_sibling,
+           topk,
+           num_verify_tokens,
+           tree_mask_mode,
+           bs,
+       )
+   else:
        sgl_build_tree_kernel_efficient(
            parent_list,
            top_scores_index,

@@ -136,3 +271,113 @@ def build_tree_kernel_efficient(
        retrive_next_sibling,
        draft_tokens,
    )
+def verify_tree_greedy_native(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    batch_size, num_draft_tokens = candidates.shape
+
+    # Optimized common case for performance.
+    if num_draft_tokens == 2 and accept_index.shape[1] == 2 and topk == 1:
+        comparison_result = candidates[:, 1] == target_predict[:, 0]
+        predicts = target_predict.flatten()
+
+        accept_index = torch.arange(
+            0, num_draft_tokens * batch_size, device=candidates.device, dtype=torch.long
+        ).reshape(batch_size, num_draft_tokens)
+        comparison_result = comparison_result.to(torch.int64)
+        accept_index_mask = accept_index[:, 1] * comparison_result
+        accept_index[:, 1] = accept_index_mask - (1 - comparison_result)
+
+        accept_token_num = comparison_result.int()
+        return predicts, accept_index, accept_token_num
+
+    # BFS
+    for bx in range(batch_size):
+        cur_candidates = candidates[bx]
+        cur_retrive_index = retrive_index[bx]
+        cur_next_token = retrive_next_token[bx]
+        cur_next_sibling = retrive_next_sibling[bx]
+        cur_target = target_predict[bx]
+
+        last_accepted_idx = cur_retrive_index[0]
+        accept_index[bx, 0] = last_accepted_idx
+        num_accepted = 0
+        cur_node = 0
+
+        for _ in range(1, num_draft_tokens):
+            cur_node = cur_next_token[cur_node]
+            found = False
+            while cur_node != -1:
+                draft_idx = cur_retrive_index[cur_node]
+                draft_token = cur_candidates[cur_node]
+                target_token = cur_target[last_accepted_idx - num_draft_tokens * bx]
+
+                if draft_token == target_token:
+                    predicts[last_accepted_idx] = target_token
+                    num_accepted += 1
+                    accept_index[bx, num_accepted] = draft_idx
+                    last_accepted_idx = draft_idx
+                    found = True
+                    break
+                else:
+                    cur_node = cur_next_sibling[cur_node]
+            if not found:
+                break
+
+        accept_token_num[bx] = num_accepted
+        predicts[last_accepted_idx] = cur_target[
+            last_accepted_idx - num_draft_tokens * bx
+        ]
+    return predicts, accept_index, accept_token_num
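A concrete run of the num_draft_tokens == 2, topk == 1 fast path (toy token ids): each request carries one draft token after the root, accepted iff it equals the target's prediction at the root.

```python
import torch

candidates = torch.tensor([[10, 42], [11, 99]])    # [root, draft] per request
target_predict = torch.tensor([[42, 7], [55, 8]])  # target's choice at each slot
match = (candidates[:, 1] == target_predict[:, 0]).to(torch.int64)
predicts = target_predict.flatten()                # [42, 7, 55, 8]
accept_index = torch.arange(4).reshape(2, 2)
accept_index[:, 1] = accept_index[:, 1] * match - (1 - match)
print(accept_index.tolist())  # [[0, 1], [2, -1]]: request 0 accepts its draft token
print(match.tolist())         # accept_token_num = [1, 0]
```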
+def verify_tree_greedy_func(
+    predicts: torch.Tensor,
+    accept_index: torch.Tensor,
+    accept_token_num: torch.Tensor,
+    candidates: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    target_predict: torch.Tensor,
+    topk: int = -1,
+):
+    if _is_cuda or _is_hip:
+        from sgl_kernel import verify_tree_greedy
+
+        verify_tree_greedy(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+        )
+    elif _is_npu:
+        predicts, accept_index, accept_token_num = verify_tree_greedy_native(
+            predicts=predicts,  # mutable
+            accept_index=accept_index,  # mutable
+            accept_token_num=accept_token_num,  # mutable
+            candidates=candidates,
+            retrive_index=retrive_index,
+            retrive_next_token=retrive_next_token,
+            retrive_next_sibling=retrive_next_sibling,
+            target_predict=target_predict,
+            topk=topk,
+        )
+    return predicts, accept_index, accept_token_num
@@ -53,9 +53,12 @@ from sglang.srt.utils import (
    get_available_gpu_memory,
    get_bool_env_var,
    is_cuda,
+   is_npu,
    next_power_of_2,
)

+_is_npu = is_npu()

if is_cuda():
    from sgl_kernel import segment_packbits  # noqa: F401
@@ -205,7 +208,7 @@ class EAGLEWorker(TpModelWorker):
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

-       if self.server_args.disable_cuda_graph:
+       if self.server_args.disable_cuda_graph or _is_npu:
            return

        # Capture draft

@@ -945,7 +948,7 @@ class EAGLEWorker(TpModelWorker):
        draft_input.hidden_states = logits_output.hidden_states

-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
......
@@ -4,7 +4,6 @@ import time
from typing import List, Optional, Tuple

import torch
-from torch.cuda import Stream as CudaStream

from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import ModelWorkerBatch

@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
    empty_context,
    fast_topk,
    get_available_gpu_memory,
+   is_npu,
    next_power_of_2,
)

+_is_npu = is_npu()

logger = logging.getLogger(__name__)

def _get_plan_stream(
    device: str,
-) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
+) -> Tuple[any, contextlib.AbstractContextManager]:
    if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
-       plan_stream: CudaStream = torch.get_device_module(device).Stream()
-       plan_stream_ctx = torch.cuda.stream(plan_stream)
+       plan_stream = torch.get_device_module(device).Stream()
+       plan_stream_ctx = torch.get_device_module(device).stream(plan_stream)
        return plan_stream, plan_stream_ctx
    else:
        return None, contextlib.nullcontext()
@@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker):
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

-       if self.server_args.disable_cuda_graph:
+       if self.server_args.disable_cuda_graph or _is_npu:
            return

        # Capture draft

@@ -456,7 +458,9 @@ class EagleDraftWorker(BaseDraftWorker):
        )

        if self.plan_stream:
-           torch.cuda.current_stream().wait_stream(self.plan_stream)
+           torch.get_device_module(self.device).current_stream().wait_stream(
+               self.plan_stream
+           )

        # Run draft extend batch in the main compute stream
        draft_logits_output = self.draft_runner.model.forward(

@@ -577,7 +581,9 @@ class EAGLEWorkerV2(BaseSpecWorker):
        # Since batch.seq_lens is allocated in another stream, we need
        # record_stream() to prevent pytorch gc and reuse the gpu memory
        # while forward_stream is still running.
-       batch.seq_lens.record_stream(torch.cuda.current_stream())
+       batch.seq_lens.record_stream(
+           torch.get_device_module(self.device).current_stream()
+       )

        # Parse args
        verify_input: EagleVerifyInput = batch.spec_info

@@ -596,7 +602,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
        # Correct some buffers due to the overlap plan
        if self.plan_stream:
-           torch.cuda.current_stream().wait_stream(self.plan_stream)
+           torch.get_device_module().current_stream().wait_stream(self.plan_stream)

        # Some values such as custom_mask and position depend on the output of draft,
        # so the previous plan step used the wrong values. Here, we need to run the related

@@ -628,7 +634,7 @@ class EAGLEWorkerV2(BaseSpecWorker):
            accept_index,
        ) = verify_input.sample(batch, logits_output)
        new_seq_lens = batch.seq_lens + accept_length
-       verify_done = torch.cuda.Event()
+       verify_done = torch.get_device_module(self.device).Event()
        verify_done.record()

        all_verified_id = predict[accept_index]
@@ -19,16 +19,22 @@ from sglang.srt.distributed.parallel_state import (
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_info import EagleVerifyInput

-if is_cuda():
+if _is_cuda:
    from sgl_kernel import fast_topk
-elif is_hip():
+elif _is_hip:
    from sgl_kernel import fast_topk
+else:
+    from sglang.srt.utils.common import fast_topk

logger = logging.getLogger(__name__)

@@ -39,7 +45,7 @@ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
-TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
+TREE_SPEC_KERNEL_AVAILABLE = _is_cuda  # This kernel is only available for CUDA now
@triton.jit

@@ -103,6 +109,36 @@ def assign_req_to_token_pool(
        load_offset += BLOCK_SIZE

+def assign_req_to_token_pool_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    batch_size: int,
+):
+    if _is_cuda or _is_hip:
+        assign_req_to_token_pool[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        torch.ops.npu.cache_loc_assign(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
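For reference, the intended semantics in pure Python (illustrative only): out_cache_loc is scattered into req_to_token[req, start:end], consuming out_cache_loc contiguously in batch order — the mirror image of the gather helper earlier in this commit.

```python
import torch

def assign_req_to_token_pool_ref(req_pool_indices, req_to_token, start, end, out_cache_loc):
    # Scatter the freshly allocated slots back into each request's row.
    offset = 0
    for r, s, e in zip(req_pool_indices.tolist(), start.tolist(), end.tolist()):
        n = e - s
        req_to_token[r, s:e] = out_cache_loc[offset : offset + n]
        offset += n

req_to_token = torch.zeros(2, 8, dtype=torch.int64)
assign_req_to_token_pool_ref(
    torch.tensor([0, 1]), req_to_token,
    start=torch.tensor([2, 0]), end=torch.tensor([4, 3]),
    out_cache_loc=torch.tensor([100, 101, 200, 201, 202]),
)
print(req_to_token.tolist())
# [[0, 0, 100, 101, 0, 0, 0, 0], [200, 201, 202, 0, 0, 0, 0, 0]]
```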
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,

@@ -331,7 +367,7 @@ def get_target_cache_loc(
)
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,

@@ -381,7 +417,7 @@ def filter_finished_cache_loc_kernel(
)

-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,

@@ -395,7 +431,7 @@ def create_accept_length_filter(
    return accept_length_filter

-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,

@@ -413,7 +449,7 @@ def select_top_k_tokens(
    tree_info = (
        topk_p.unsqueeze(1),  # shape: (b, 1, topk)
        topk_index,  # shape: (b, topk)
-       torch.arange(-1, topk, dtype=torch.long, device="cuda")
+       torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
        .unsqueeze(0)
        .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
    )
......
@@ -3106,12 +3106,16 @@ def apply_module_patch(target_module, target_function, wrappers):
    setattr(original_module, target_function, candidate)

    for key, value in sys.modules.copy().items():
+       try:
            if (
                target_function is not None
                and hasattr(value, target_function)
                and id(getattr(value, target_function)) == original_function_id
            ):
                setattr(value, target_function, candidate)
+       except ImportError as e:
+           # Ignore some modules reporting ImportError when calling hasattr
+           logger.warning(f"Ignore {value} reports ImportError with:\n{str(e)}")

def parse_module_path(module_path, function_name, create_dummy):
......
+import os
+import unittest
+from types import SimpleNamespace
+from urllib.parse import urlparse
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    is_in_ci,
+    popen_launch_server,
+    run_bench_offline_throughput,
+)
+
+TEST_MODEL_MATRIX = {
+    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-R1-0528-W8A8": {
+        "accuracy": 0.95,
+        "latency": 1000,
+        "output_throughput": 6,
+    },
+}
+
+
+class TestAscendDeepSeekMTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.models = TEST_MODEL_MATRIX.keys()
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.common_args = [
+            "--trust-remote-code",
+            "--attention-backend",
+            "ascend",
+            "--quantization",
+            "w8a8_int8",
+            "--mem-fraction-static",
+            0.8,
+            "--disable-radix-cache",
+            "--chunked-prefill-size",
+            32768,
+            "--tp-size",
+            16,
+            "--speculative-algorithm",
+            "NEXTN",
+            "--speculative-num-steps",
+            1,
+            "--speculative-eagle-topk",
+            1,
+            "--speculative-num-draft-tokens",
+            2,
+        ]
+        cls.extra_envs = {
+            "SGLANG_NPU_USE_MLAPO": "1",
+        }
+        os.environ.update(cls.extra_envs)
+
+    def test_a_gsm8k(self):
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing accuracy: {model} ===##")
+
+                process = popen_launch_server(
+                    model,
+                    self.base_url,
+                    timeout=1500,
+                    other_args=[
+                        *self.common_args,
+                    ],
+                )
+
+                try:
+                    args = SimpleNamespace(
+                        num_shots=5,
+                        data_path=None,
+                        num_questions=1319,
+                        max_new_tokens=512,
+                        parallel=128,
+                        host=f"http://{self.url.hostname}",
+                        port=int(self.url.port),
+                    )
+                    metrics = run_eval_few_shot_gsm8k(args)
+
+                    self.assertGreaterEqual(
+                        metrics["accuracy"],
+                        TEST_MODEL_MATRIX[model]["accuracy"],
+                    )
+                finally:
+                    kill_process_tree(process.pid)
+
+    def test_b_throughput(self):
+        for model in self.models:
+            with self.subTest(model=model):
+                print(f"##=== Testing throughput: {model} ===##")
+
+                output_throughput = run_bench_offline_throughput(
+                    model,
+                    [
+                        *self.common_args,
+                    ],
+                )
+
+                print(f"##=== {model} throughput: {output_throughput} ===##")
+
+                if is_in_ci():
+                    self.assertGreater(
+                        output_throughput,
+                        TEST_MODEL_MATRIX[model]["output_throughput"],
+                    )
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -359,6 +359,7 @@ suite_ascend = {
    ],
    "per-commit-16-ascend-a3": [
        TestFile("ascend/test_ascend_deepep.py", 400),
+       TestFile("ascend/test_ascend_deepseek_mtp.py", 400),
    ],
}
......