Unverified Commit 137e75da authored by Even Zhou, committed by GitHub

[Feature] Optimize DeepSeek's DeepEP on Ascend NPU (#8355)


Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: Hexq0210 <hexq0809521@gmail.com>
parent 52e1f52f
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
+_is_npu = is_npu()
+
 
 @dataclass
 class GraphCaptureContext:
@@ -591,7 +593,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if not supports_custom_op():
+        if _is_npu or not supports_custom_op():
            self._all_gather_into_tensor(output, input)
         else:
            torch.ops.sglang.reg_all_gather_into_tensor(
@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not is_npu(),
+        use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
...
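For context, a minimal sketch of what the non-custom-op branch above reduces to (assumed semantics, not the actual implementation): when `_is_npu` is true, the group falls back to the plain torch.distributed collective instead of the registered sglang custom op.

import torch
import torch.distributed as dist

def all_gather_into_tensor_fallback(output: torch.Tensor, input: torch.Tensor, group=None):
    # Sketch of the fallback path: gather every rank's `input` shard into
    # `output` via torch.distributed directly (e.g. the HCCL backend on Ascend),
    # which works where the registered custom op is unavailable.
    dist.all_gather_into_tensor(output, input, group=group)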
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
...
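A hedged sketch of how the new `get_cuda_graph_seq_len_fill_value` hook is typically consumed during graph capture; the helper below is hypothetical, only the method name and its return value of 1 come from the diff.

import torch

def pad_seq_lens_for_capture(real_seq_lens: torch.Tensor, capture_bs: int, backend) -> torch.Tensor:
    # Padded (dummy) slots in a captured graph still need a harmless sequence
    # length; the attention backend advertises it, and AscendAttnBackend uses 1.
    fill_value = backend.get_cuda_graph_seq_len_fill_value()
    seq_lens = torch.full((capture_bs,), fill_value, dtype=torch.int32, device=real_seq_lens.device)
    seq_lens[: real_seq_lens.numel()] = real_seq_lens
    return seq_lens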
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
                 return_recv_hook=True,
             )
 
-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
...
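As a reading aid, a plain-PyTorch reference for what each grouped W8A8 matmul above computes per expert group: int8 activations times int8 weights, accumulated in int32, then dequantized by the per-channel weight scale and the per-token dynamic scale and cast to bfloat16. The cumulative-offset interpretation of the group list and the [E, N, K] weight layout are assumptions made here for clarity, not the kernel's documented contract.

import torch

def grouped_w8a8_matmul_ref(x_q, w_q, w_scale, x_scale, group_offsets):
    # x_q: [T, K] int8 activations, x_scale: [T] per-token scales
    # w_q: [E, N, K] int8 expert weights, w_scale: [E, N] per-channel scales
    # group_offsets: assumed cumulative token offsets per expert
    outs, start = [], 0
    for e, end in enumerate(group_offsets.tolist()):
        acc = x_q[start:end].to(torch.int32) @ w_q[e].to(torch.int32).T      # int32 accumulation
        out = acc.to(torch.float32) * w_scale[e][None, :] * x_scale[start:end, None]
        outs.append(out.to(torch.bfloat16))
        start = end
    return torch.cat(outs, dim=0)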
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
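For illustration, how the new NamedTuple is built and consumed; field order matters because DeepEPMoE.forward_npu destructures it positionally. The shapes and the int8 element dtype of the quantized activations below are assumptions for the example, not values taken from the diff.

import torch
from sglang.srt.layers.moe.token_dispatcher import AscendDeepEPLLOutput

tokens, hidden, top_k, num_local_experts = 16, 7168, 8, 4  # illustrative sizes only
out = AscendDeepEPLLOutput(
    hidden_states_fp8=(torch.empty(tokens, hidden, dtype=torch.int8), torch.empty(tokens)),
    topk_idx=torch.empty(tokens, top_k, dtype=torch.int64),
    topk_weights=torch.empty(tokens, top_k),
    masked_m=torch.empty(num_local_experts, dtype=torch.int64),
    seg_indptr=torch.empty(num_local_experts + 1, dtype=torch.int32),
    expected_m=tokens,
)
# positional unpacking, mirroring DeepEPMoE.forward_npu
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = out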
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
...
@@ -250,10 +250,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
...
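The float32 casts keep router_logits and correction_bias in a matching dtype for npu_moe_gating_top_k. As a rough reference only, a simplified plain-PyTorch sketch of group-limited top-k gating; the score function, renormalization, and scaling factor are omitted, and scoring a group by its max is a simplification rather than the kernel's exact rule.

import torch

def grouped_topk_sketch(router_logits, bias, k, k_group, group_count):
    scores = torch.sigmoid(router_logits.float()) + bias.float()               # [T, E]
    T, E = scores.shape
    group_scores = scores.view(T, group_count, E // group_count).amax(dim=-1)  # [T, G]
    keep_groups = group_scores.topk(k_group, dim=-1).indices
    keep = torch.zeros_like(group_scores, dtype=torch.bool).scatter_(1, keep_groups, True)
    masked = scores.masked_fill(~keep.repeat_interleave(E // group_count, dim=1), float("-inf"))
    topk_weights, topk_ids = masked.topk(k, dim=-1)
    return topk_weights, topk_ids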
@@ -3,7 +3,18 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
 from torch.nn.parameter import Parameter
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        original_dtype = x.dtype
-        x = x.to(torch.float32)
         if residual is not None:
-            x = x + residual.to(torch.float32)
-            residual = x.to(original_dtype)
-
-        x = (
-            torch_npu.npu_rms_norm(
-                x, self.weight.to(torch.float32), self.variance_epsilon
-            )[0]
-            + self.bias
-        )
-
-        if residual is None:
-            return x.to(original_dtype)
-        return x.to(original_dtype), residual
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            out = out + self.bias
+            return out.to(x.dtype), residual_out
+
+        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+        out = out + self.bias
+        return out.to(x.dtype)
 
     return _rmsnorm_forward_oot
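A plain-PyTorch reference (assuming standard RMSNorm semantics) for what the fused npu_add_rms_norm call is expected to compute: the residual add happens first, the sum is RMS-normalized and scaled, and the pre-normalization sum is returned as the new residual.

import torch

def add_rms_norm_ref(x, residual, weight, eps):
    # reference: returns (rms_norm(x + residual) * weight, x + residual)
    h = (x + residual).float()
    out = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).to(x.dtype), (x + residual)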
@@ -571,8 +576,10 @@ class NPU_W8A8LinearMethodImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
@@ -583,8 +590,12 @@ class NPU_W8A8LinearMethodImpl:
                 -1,
                 True,
             )
-
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
         return torch_npu.npu_quant_matmul(
             x,
             layer.weight,
@@ -651,13 +662,21 @@ class NPU_W8A8LinearMethodMTImpl:
         layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
+        # To prevent import loops
+        from sglang.srt.layers.linear import RowParallelLinear
+
         original_dtype = x.dtype
         if original_dtype != torch.int8:
             x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
 
-        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in Attention TP>1 case)
+        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
+            quant_bias = None
+        else:
+            quant_bias = layer.quant_bias
+
         return ops.quant_matmul(
             x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
         )
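Why the bias is fused only on rank 0 of a RowParallelLinear: the per-rank partial products are summed by an all-reduce, so adding quant_bias on every rank would add it tp_size times. A small standalone simulation (hypothetical helper, two simulated ranks):

import torch

def row_parallel_linear_ref(x_shards, w_shards, bias):
    # Each rank holds a slice of the input features and of the weight's input dim.
    partials = []
    for rank, (x_r, w_r) in enumerate(zip(x_shards, w_shards)):
        partial = x_r @ w_r.T
        if rank == 0:                      # fuse the bias exactly once
            partial = partial + bias
        partials.append(partial)
    return torch.stack(partials).sum(0)    # stands in for the all-reduce

x, w, b = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6)
assert torch.allclose(x @ w.T + b,
                      row_parallel_linear_ref(x.chunk(2, dim=1), w.chunk(2, dim=1), b),
                      atol=1e-5)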
@@ -737,11 +756,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
@@ -780,7 +794,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         original_dtype = x.dtype
-        # use ATB quantize
         quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
         return torch_npu.npu_quant_matmul(
             quant_out,
@@ -863,11 +876,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.linear import RowParallelLinear
-
-        if isinstance(layer, RowParallelLinear):
-            tp_rank = get_tensor_model_parallel_rank()
-            return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
...
@@ -680,7 +680,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip or _is_npu:
+        if _is_hip:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -765,6 +765,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
+        # and generalization to more scenarios will be supported in the future.
+        if query.shape[1] * query.shape[2] > 4096:
+            return self.forward_native(positions, query, key, offsets)
+        num_tokens = query.shape[0]
+        rotary_mode = "half" if self.is_neox_style else "interleave"
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        query_rot, key_rot = torch_npu.npu_mrope(
+            torch.add(positions, offsets) if offsets is not None else positions,
+            query_rot.reshape(num_tokens, -1),
+            key_rot.reshape(num_tokens, -1),
+            self.cos_sin_cache,
+            self.rotary_dim,
+            mrope_section=[0, 0, 0],
+            rotary_mode=rotary_mode,
+        )
+        query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
+        key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
...
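For reference, the conventional meaning of the two rotary_mode values passed to npu_mrope, written in plain PyTorch; this is assumed to match the kernel's definition and matches the diff's pairing of "half" with is_neox_style.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # "half" mode (GPT-NeoX style): the rotated pair is (x[:d/2], x[d/2:]).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleave(x: torch.Tensor) -> torch.Tensor:
    # "interleave" mode (GPT-J style): the rotated pair is (x[0::2], x[1::2]).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)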