Unverified Commit 16cd550c authored by Even Zhou, committed by GitHub

Support Qwen3-Next on Ascend NPU (#10379)

parent d5e2a374
@@ -73,6 +73,6 @@ jobs:
          push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
          provenance: false
          build-args: |
-            SGLANG_KERNEL_NPU_TAG=20250901
+            SGLANG_KERNEL_NPU_TAG=20250913
            CANN_VERSION=${{ matrix.cann_version }}
            DEVICE_TYPE=${{ matrix.device_type }}
@@ -69,6 +69,6 @@ jobs:
          push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
          provenance: false
          build-args: |
-            SGLANG_KERNEL_NPU_TAG=20250901
+            SGLANG_KERNEL_NPU_TAG=20250913
            CANN_VERSION=${{ matrix.cann_version }}
            DEVICE_TYPE=${{ matrix.device_type }}
@@ -13,7 +13,8 @@ ARG PYTORCH_VERSION=2.6.0
ARG TORCHVISION_VERSION=0.21.0
ARG PTA_URL="https://gitee.com/ascend/pytorch/releases/download/v7.1.0.1-pytorch2.6.0/torch_npu-2.6.0.post1-cp311-cp311-manylinux_2_28_aarch64.whl"
ARG VLLM_TAG=v0.8.5
-ARG TRITON_ASCEND_URL=https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl
+ARG TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl"
+ARG BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/Ascend-BiSheng-toolkit_aarch64.run"
ARG SGLANG_TAG=main
ARG ASCEND_CANN_PATH=/usr/local/Ascend/ascend-toolkit
ARG SGLANG_KERNEL_NPU_TAG=main
@@ -81,13 +82,17 @@ RUN git clone https://github.com/sgl-project/sglang --branch $SGLANG_TAG && \
    rm -rf sglang

# Install Deep-ep
-RUN git clone --branch $SGLANG_KERNEL_NPU_TAG https://github.com/sgl-project/sgl-kernel-npu.git \
+# pin wheel to 0.45.1 ref: https://github.com/pypa/wheel/issues/662
+RUN pip install wheel==0.45.1 && git clone --branch $SGLANG_KERNEL_NPU_TAG https://github.com/sgl-project/sgl-kernel-npu.git \
    && export LD_LIBRARY_PATH=${ASCEND_CANN_PATH}/latest/runtime/lib64/stub:$LD_LIBRARY_PATH && \
    source ${ASCEND_CANN_PATH}/set_env.sh && \
    cd sgl-kernel-npu && \
    bash build.sh \
-    && pip install output/deep_ep*.whl --no-cache-dir \
+    && pip install output/deep_ep*.whl output/sgl_kernel_npu*.whl --no-cache-dir \
    && cd .. && rm -rf sgl-kernel-npu \
    && cd "$(pip show deep-ep | awk '/^Location:/ {print $2}')" && ln -s deep_ep/deep_ep_cpp*.so
+
+# Install Bisheng
+RUN wget ${BISHENG_URL} && chmod a+x Ascend-BiSheng-toolkit_aarch64.run && ./Ascend-BiSheng-toolkit_aarch64.run --install && rm Ascend-BiSheng-toolkit_aarch64.run

CMD ["/bin/bash"]
@@ -158,7 +158,7 @@ def _layer_norm_fwd(
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
-    with torch.cuda.device(x.device.index):
+    with torch.get_device_module(x.device).device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](
            x,
            out,
...
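
The change above swaps a hard-coded torch.cuda.device(...) context for a lookup keyed on the tensor's device, so the same Triton launch code runs on CUDA and Ascend NPU. A minimal, self-contained sketch of that pattern follows; the device_ctx helper and its CPU fallback are illustrative, not part of the patch.

import contextlib

import torch


def device_ctx(x: torch.Tensor):
    # On accelerators, torch.get_device_module(x.device) resolves to torch.cuda,
    # torch.npu, etc., so one context manager covers both; on CPU this is a no-op.
    if x.device.type == "cpu":
        return contextlib.nullcontext()
    return torch.get_device_module(x.device).device(x.device.index)


x = torch.ones(4)  # move to "cuda" or "npu" to exercise the accelerator path
with device_ctx(x):
    y = x * 2  # stand-in for the Triton kernel launch in _layer_norm_fwd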
@@ -23,6 +23,22 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import is_npu
+
+if is_npu():
+    from sgl_kernel_npu.fla.chunk import chunk_gated_delta_rule_npu
+    from sgl_kernel_npu.fla.fused_sigmoid_gating_recurrent import (
+        fused_sigmoid_gating_delta_rule_update_npu,
+    )
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_fn_npu,
+        causal_conv1d_update_npu,
+    )
+
+    chunk_gated_delta_rule = chunk_gated_delta_rule_npu
+    fused_sigmoid_gating_delta_rule_update = fused_sigmoid_gating_delta_rule_update_npu
+    causal_conv1d_fn = causal_conv1d_fn_npu
+    causal_conv1d_update = causal_conv1d_update_npu

@dataclass
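
The block above resolves the linear-attention kernels once at import time: on NPU hosts the sgl_kernel_npu implementations are bound to the same names the rest of the file already calls. A small illustrative sketch of that aliasing pattern; the placeholder functions stand in for the real kernels and _is_npu_available is a simplified stand-in for sglang.srt.utils.is_npu.

import torch


def _is_npu_available() -> bool:
    # Simplified stand-in for sglang.srt.utils.is_npu().
    return hasattr(torch, "npu") and torch.npu.is_available()


def chunk_gated_delta_rule_ref(x):
    return x  # placeholder for the Triton/CUDA implementation


def chunk_gated_delta_rule_npu_ref(x):
    return x  # placeholder for the sgl_kernel_npu implementation


# Resolve the backend once at import time; every later call site uses one name.
chunk_gated_delta_rule = (
    chunk_gated_delta_rule_npu_ref if _is_npu_available() else chunk_gated_delta_rule_ref
)

print(chunk_gated_delta_rule(torch.zeros(2)))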
@@ -85,10 +101,12 @@ class MambaAttnBackend(AttentionBackend):
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(max_bs):
            self.state_indices_list.append(
-                torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda")
+                torch.full(
+                    (i + 1,), self.pad_slot_id, dtype=torch.int32, device=self.device
+                )
            )
            self.query_start_loc_list.append(
-                torch.empty((i + 2,), dtype=torch.int32, device="cuda")
+                torch.empty((i + 2,), dtype=torch.int32, device=self.device)
            )

    def init_forward_metadata_capture_cuda_graph(
@@ -110,7 +128,7 @@ class MambaAttnBackend(AttentionBackend):
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
-                    device="cuda",
+                    device=self.device,
                )
            )
        else:
@@ -152,7 +170,7 @@ class MambaAttnBackend(AttentionBackend):
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
-                    device="cuda",
+                    device=self.device,
                )
            )
        if num_padding > 0:
...
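
Both hunks above replace hard-coded device="cuda" allocations with the backend's own device, so graph-capture buffers land on the NPU when that is where the model runs. A toy sketch of the same allocation pattern; GraphStateBuffers is illustrative, not the real MambaAttnBackend.

import torch


class GraphStateBuffers:
    # Pre-allocate per-batch-size index buffers on the backend's device instead of
    # a hard-coded "cuda", mirroring the device=self.device change above.
    def __init__(self, max_bs: int, pad_slot_id: int, device: torch.device):
        self.device = device
        self.state_indices_list = []
        self.query_start_loc_list = []
        for i in range(max_bs):
            self.state_indices_list.append(
                torch.full((i + 1,), pad_slot_id, dtype=torch.int32, device=self.device)
            )
            self.query_start_loc_list.append(
                torch.empty((i + 2,), dtype=torch.int32, device=self.device)
            )


bufs = GraphStateBuffers(max_bs=4, pad_slot_id=-1, device=torch.device("cpu"))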
@@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache):
        self,
        size: int,
        dtype: torch.dtype,
+        page_size: int,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
@@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache):
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        self.page_size = page_size
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
-        self.full_kv_pool = MHATokenToKVPool(
+        if _is_npu:
+            TokenToKVPoolClass = AscendTokenToKVPool
+        else:
+            TokenToKVPoolClass = MHATokenToKVPool
+        self.full_kv_pool = TokenToKVPoolClass(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
@@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
    ):
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
...
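
HybridLinearKVPool now takes page_size from its caller and, on NPU, builds its full-attention pool with AscendTokenToKVPool instead of MHATokenToKVPool. A compact sketch of that selection logic with stand-in classes; DefaultKVPool and AscendKVPool are placeholders, not the real constructors.

class DefaultKVPool:
    # Stand-in for MHATokenToKVPool; only the parameters relevant to the hunk are kept.
    def __init__(self, size: int, page_size: int):
        self.size = size
        self.page_size = page_size


class AscendKVPool(DefaultKVPool):
    # Stand-in for AscendTokenToKVPool.
    pass


def make_full_kv_pool(size: int, page_size: int, on_npu: bool) -> DefaultKVPool:
    # On NPU the hybrid pool now receives the runner's real page size (128 for the
    # Ascend backend); elsewhere the caller passes 1, as in the ModelRunner hunk below.
    pool_cls = AscendKVPool if on_npu else DefaultKVPool
    return pool_cls(size=size, page_size=page_size)


pool = make_full_kv_pool(size=1 << 16, page_size=128, on_npu=True)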
@@ -1567,6 +1567,7 @@ class ModelRunner:
            )
        elif self.is_hybrid_gdn:
            self.token_to_kv_pool = HybridLinearKVPool(
+                page_size=self.page_size if _is_npu else 1,
                size=self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(
@@ -1601,7 +1602,10 @@ class ModelRunner:
        # Initialize token_to_kv_pool_allocator
        need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
        if self.token_to_kv_pool_allocator is None:
-            if self.server_args.attention_backend == "ascend":
+            if _is_npu and self.server_args.attention_backend in [
+                "ascend",
+                "hybrid_linear_attn",
+            ]:
                self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
@@ -1819,15 +1823,22 @@ class ModelRunner:
            assert (
                self.is_hybrid_gdn
            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
                HybridLinearAttnBackend,
                MambaAttnBackend,
            )

-            full_attn_backend = FlashAttentionBackend(self)
+            if _is_npu:
+                from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+                full_attn_backend = AscendAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashattention_backend import (
+                    FlashAttentionBackend,
+                )
+
+                full_attn_backend = FlashAttentionBackend(self)
            linear_attn_backend = MambaAttnBackend(self)
            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
            return HybridLinearAttnBackend(
...
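
The backend construction above moves the FlashAttention import into the non-NPU branch, so CUDA-only modules are never imported on Ascend hosts (and vice versa). A hedged sketch of that deferred-import pattern; it assumes an sglang installation and reuses the module paths shown in the hunk, with on_npu standing in for the runner's platform flag.

import importlib


def load_full_attn_backend_cls(on_npu: bool):
    # Deferred, branch-local imports: only the module for the platform in use is ever
    # imported, so a missing flash-attn build on NPU cannot break startup.
    if on_npu:
        module = importlib.import_module("sglang.srt.layers.attention.ascend_backend")
        return module.AscendAttnBackend
    module = importlib.import_module("sglang.srt.layers.attention.flashattention_backend")
    return module.FlashAttentionBackend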
@@ -46,10 +46,11 @@ from sglang.srt.model_loader.weight_utils import (
    sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs
+from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
+_is_npu = is_npu()

import triton
import triton.language as tl
@@ -327,7 +328,7 @@ class Qwen3GatedDeltaNet(nn.Module):
            eps=self.layer_norm_epsilon,
            group_size=None,
            norm_before_gate=True,
-            device=torch.cuda.current_device(),
+            device=torch.get_device_module().current_device(),
            dtype=config.torch_dtype,
        )
@@ -388,7 +389,7 @@ class Qwen3GatedDeltaNet(nn.Module):
        return query, key, value, z, b, a

    def _forward_input_proj(self, hidden_states: torch.Tensor):
-        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
        seq_len, _ = hidden_states.shape
        if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
            current_stream = torch.cuda.current_stream()
@@ -454,6 +455,8 @@ class Qwen3GatedDeltaNet(nn.Module):
            "dt_bias": self.dt_bias,
            "layer_id": self.layer_id,
            "seq_len": seq_len,
+            "num_k_heads": self.num_k_heads,
+            "num_v_heads": self.num_v_heads,
            "z": z,
        }
...
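
Setting DUAL_STREAM_TOKEN_THRESHOLD to 0 on NPU makes the seq_len < threshold check always false, so the dual-stream projection fast path is skipped entirely on that platform. A runnable toy version of that control flow; the _is_npu probe and the return strings are illustrative only.

import torch

_is_npu = hasattr(torch, "npu") and torch.npu.is_available()  # simplified stand-in for is_npu()


def forward_input_proj(hidden_states: torch.Tensor) -> str:
    # With the threshold forced to 0 on NPU, the dual-CUDA-stream branch is never taken.
    DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
    seq_len, _ = hidden_states.shape
    if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
        return "dual-stream path"
    return "single-stream path"


print(forward_input_proj(torch.zeros(8, 16)))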
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
    is_cuda,
    is_flashinfer_available,
    is_hip,
+    is_npu,
    is_port_available,
    is_remote_url,
    is_sm90_supported,
@@ -569,7 +570,7 @@ class ServerArgs:
            )
            self.disable_cuda_graph = True

-        if self.attention_backend == "ascend":
+        if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
            logger.warning(
                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
            )
...
@@ -45,16 +45,22 @@ wget -O "${PTA_NAME}" "${PTA_URL}" && ${PIP_INSTALL} "./${PTA_NAME}"

### Install Triton-Ascend
-TRITON_ASCEND_NAME="triton_ascend-3.2.0.dev20250729-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
-TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${TRITON_ASCEND_NAME}"
+TRITON_ASCEND_NAME="triton_ascend-3.2.0+gitb0ea0850-cp311-cp311-linux_aarch64.whl"
+TRITON_ASCEND_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/triton_ascend-3.2.0%2Bgitb0ea0850-cp311-cp311-linux_aarch64.whl"
${PIP_INSTALL} attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11
wget -O "${TRITON_ASCEND_NAME}" "${TRITON_ASCEND_URL}" && ${PIP_INSTALL} "./${TRITON_ASCEND_NAME}"

+### Install BiSheng
+BISHENG_NAME="Ascend-BiSheng-toolkit_aarch64.run"
+BISHENG_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com/sglang/${BISHENG_NAME}"
+wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./${BISHENG_NAME}" --install && rm "${BISHENG_NAME}"
+
### Install sgl-kernel-npu
-SGL_KERNEL_NPU_TAG="20250901"
+SGL_KERNEL_NPU_TAG="20250913"
git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}
-(cd sgl-kernel-npu && bash ./build.sh -a deepep && pip install output/deep_ep*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)
+(cd sgl-kernel-npu && bash ./build.sh && pip install output/deep_ep*.whl output/sgl_kernel_npu*.whl && cd "$(pip show deep-ep | grep -E '^Location:' | awk '{print $2}')" && ln -s deep_ep/deep_ep_cpp*.so)

### Install SGLang
...
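
After the build step above, a quick check that both wheels import cleanly can save a debugging round-trip. The snippet below is an optional convenience, not part of the install script; the module names are taken from the imports and wheel names elsewhere in this change.

import importlib

for name in ("deep_ep", "sgl_kernel_npu"):
    try:
        importlib.import_module(name)
        print(f"{name}: import OK")
    except ImportError as exc:
        print(f"{name}: import FAILED ({exc})")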