Unverified commit cafebef1, authored by Even Zhou and committed by GitHub

[NPU] bugfix for Qwen3-Next and performance update (#11969)

parent 73dfd2df
@@ -73,6 +73,6 @@ jobs:
       push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
       provenance: false
       build-args: |
-        SGLANG_KERNEL_NPU_TAG=20250926
+        SGLANG_KERNEL_NPU_TAG=20251030
         CANN_VERSION=${{ matrix.cann_version }}
         DEVICE_TYPE=${{ matrix.device_type }}
@@ -69,6 +69,6 @@ jobs:
       push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
       provenance: false
       build-args: |
-        SGLANG_KERNEL_NPU_TAG=20250926
+        SGLANG_KERNEL_NPU_TAG=20251030
         CANN_VERSION=${{ matrix.cann_version }}
         DEVICE_TYPE=${{ matrix.device_type }}
@@ -12,7 +12,9 @@ import triton
 import triton.language as tl
 from einops import rearrange

-from sglang.srt.utils import device_context
+from sglang.srt.utils import device_context, is_npu
+
+_is_npu = is_npu()


 def rms_norm_ref(
@@ -182,6 +184,10 @@ def _layer_norm_fwd(
     return out, mean, rstd


+if _is_npu:
+    from sgl_kernel_npu.fla.layernorm_gated import layer_norm_fwd_npu as _layer_norm_fwd
+
+
 def rms_norm_gated(
     *,
     x,
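The hunk above swaps the Triton `_layer_norm_fwd` for an Ascend kernel by rebinding the module-level name at import time, so `rms_norm_gated` and the other callers in this file dispatch to the NPU path without any call-site changes. Below is a minimal, self-contained sketch of that rebinding pattern; the default implementation and the `is_npu` probe here are illustrative stand-ins, not sglang's actual code.

```python
import torch


def _layer_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Illustrative default path standing in for the Triton kernel; the real
    # function also returns mean/rstd (see `return out, mean, rstd` above).
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight


def is_npu() -> bool:
    # Stand-in for sglang.srt.utils.is_npu(): true when torch_npu is importable.
    try:
        import torch_npu  # noqa: F401

        return True
    except ImportError:
        return False


_is_npu = is_npu()

if _is_npu:
    # Rebind the module-level name once, at import time; every caller that
    # resolves _layer_norm_fwd through this module now gets the Ascend kernel.
    from sgl_kernel_npu.fla.layernorm_gated import (  # noqa: F811
        layer_norm_fwd_npu as _layer_norm_fwd,
    )
```

Because the override happens at import, the NPU kernel presumably has to match the Triton `_layer_norm_fwd` signature and return convention, which would explain the alias over a per-call branch.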
@@ -13,16 +13,6 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.distributed.utils import divide
-from sglang.srt.layers.attention.mamba.causal_conv1d import (
-    causal_conv1d_fn,
-    causal_conv1d_update,
-)
-from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
-    causal_conv1d_fn as causal_conv1d_fn_triton,
-)
-from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
-    causal_conv1d_update as causal_conv1d_update_triton,
-)
 from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
 from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
 from sglang.srt.layers.attention.mamba.ops import (
@@ -40,7 +30,26 @@ from sglang.srt.model_loader.weight_utils import (
     composed_weight_loader,
     sharded_weight_loader,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+
+if is_cuda():
+    from sglang.srt.layers.attention.mamba.causal_conv1d import (
+        causal_conv1d_fn,
+        causal_conv1d_update,
+    )
+    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+        causal_conv1d_fn as causal_conv1d_fn_triton,
+    )
+    from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+        causal_conv1d_update as causal_conv1d_update_triton,
+    )
+elif is_npu():
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_fn_npu as causal_conv1d_fn,
+    )
+    from sgl_kernel_npu.mamba.causal_conv1d import (
+        causal_conv1d_update_npu as causal_conv1d_update,
+    )

 LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
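Both branches bind the same two public names, `causal_conv1d_fn` and `causal_conv1d_update`, so the mixer's call sites need no platform branching. For orientation, the prefill-side operation both backends implement is a depthwise causal 1-D convolution; here is a hedged pure-PyTorch sketch of those semantics, simplified to omit cache/state handling, with the activation handling an assumption based on the upstream causal-conv1d reference.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def causal_conv1d_ref(
    x: torch.Tensor,       # (batch, dim, seqlen)
    weight: torch.Tensor,  # (dim, width): one short filter per channel
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
):
    dim, width = weight.shape
    # Left-pad by width-1 so output position t only sees inputs at <= t (causality).
    x = F.pad(x, (width - 1, 0))
    out = F.conv1d(x, weight.unsqueeze(1), bias, groups=dim)  # depthwise conv
    return F.silu(out) if activation in ("silu", "swish") else out


out = causal_conv1d_ref(torch.randn(2, 64, 10), torch.randn(64, 4), activation="silu")
assert out.shape == (2, 64, 10)
```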
@@ -314,16 +314,41 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        global_num_experts = router_logits.shape[-1]
-
-        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256:
+        use_grouped_topk = self.topk_config.use_grouped_topk
+        torch_native = self.topk_config.torch_native
+        renormalize = self.topk_config.renormalize
+
+        if not use_grouped_topk and not torch_native:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+                router_logits,
+                k=self.topk_config.top_k,
+            )
+            topk_weights = topk_weights.to(torch.float32)
+            if renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+            return StandardTopKOutput(topk_weights, topk_ids, _)
+
+        if use_grouped_topk and not torch_native and router_logits.shape[-1] == 256:
+            # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
-            router_logits = router_logits.to(torch.float32)
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-                router_logits,
+                router_logits.to(torch.float32),
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
                 k_group=self.topk_config.topk_group,
@@ -335,7 +360,7 @@ class TopK(CustomOp):
                 eps=float(1e-20),
             )

-            if self.topk_config.renormalize:
+            if renormalize:
                 topk_weights_sum = (
                     topk_weights.sum(dim=-1, keepdim=True)
                     if self.topk_config.num_fused_shared_experts == 0
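The first hunk adds a fast path for non-grouped routing: `npu_moe_gating_top_k_softmax` fuses the softmax over the router logits with the top-k selection, and the renormalization that follows matches the existing grouped path (when shared experts are fused in, the last column is excluded from the normalizer). A hedged pure-PyTorch reference of what the fused op computes in the plain, no-shared-experts case:

```python
import torch


def topk_softmax_ref(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    # Softmax over all experts, then keep the k largest gates per token: the
    # two steps npu_moe_gating_top_k_softmax fuses on Ascend.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # Rescale so the kept gates sum to 1 per token
        # (the num_fused_shared_experts == 0 branch of the diff).
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


weights, ids = topk_softmax_ref(torch.randn(4, 256), top_k=8)
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))
```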
@@ -478,6 +478,13 @@ class Qwen3GatedDeltaNet(nn.Module):
         # reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
+
+        # Add padding for DP-Attn
+        if is_dp_attention_enabled():
+            core_attn_out_pad = torch.zeros_like(z)
+            core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
+            core_attn_out = core_attn_out_pad
+
         core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
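Under DP attention the gate tensor `z` is padded to the per-rank batch size while `core_attn_out` only covers the rows this rank actually computed, so the gated RMSNorm would otherwise see mismatched shapes. A self-contained sketch of the padding step, with made-up sizes for illustration:

```python
import torch

# Per-rank padded batch (e.g. the max token count across DP ranks) vs. the
# rows this rank actually computed; both sizes here are illustrative.
num_padded_rows, num_real_rows, hidden = 8, 5, 16

z = torch.randn(num_padded_rows, hidden)            # gate tensor, already padded
core_attn_out = torch.randn(num_real_rows, hidden)  # unpadded attention output

# Same three lines as the diff: zero-fill a z-shaped buffer, copy the real
# rows in, and let the trailing padded rows stay zero.
core_attn_out_pad = torch.zeros_like(z)
core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
core_attn_out = core_attn_out_pad

assert core_attn_out.shape == z.shape  # self.norm(core_attn_out, z) is now shape-safe
```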
@@ -59,7 +59,7 @@ wget -O "${BISHENG_NAME}" "${BISHENG_URL}" && chmod a+x "${BISHENG_NAME}" && "./
 ### Install sgl-kernel-npu

-SGL_KERNEL_NPU_TAG="20250926"
+SGL_KERNEL_NPU_TAG="20251030"
 git clone --depth 1 https://github.com/sgl-project/sgl-kernel-npu.git --branch ${SGL_KERNEL_NPU_TAG}

 # pin wheel to 0.45.1 ref: https://github.com/pypa/wheel/issues/662
 pip install wheel==0.45.1