"examples/mxnet/vscode:/vscode.git/clone" did not exist on "3d6548432370060f4e231e0d0edb5df4b0318d12"
Unverified Commit 137e75da authored by Even Zhou's avatar Even Zhou Committed by GitHub
Browse files

[Feature] Optimize DeepSeek's DeepEP on Ascend NPU (#8355)


Co-authored-by: default avatarronnie_zheng <zl19940307@163.com>
Co-authored-by: default avatarHexq0210 <hexq0809521@gmail.com>
parent 52e1f52f
......@@ -50,6 +50,8 @@ from sglang.srt.utils import (
supports_custom_op,
)
_is_npu = is_npu()
@dataclass
class GraphCaptureContext:
......@@ -591,7 +593,7 @@ class GroupCoordinator:
)
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
if not supports_custom_op():
if _is_npu or not supports_custom_op():
self._all_gather_into_tensor(output, input)
else:
torch.ops.sglang.reg_all_gather_into_tensor(
......@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not is_npu(),
use_pynccl=not _is_npu,
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
......
......@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
)
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q,
......
......@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
......@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
return_recv_hook=True,
)
if self.deepep_mode.enable_low_latency():
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
......@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
else:
elif not _is_npu:
self.w13_weight_fp8 = (
self.w13_weight,
(
......@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
if _use_aiter:
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output)
if _is_npu:
return self.forward_npu(dispatch_output)
if dispatch_output.format.is_deepep_normal():
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
......@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
return down_output
def forward_npu(
self,
dispatch_output: DeepEPLLOutput,
):
if TYPE_CHECKING:
assert isinstance(dispatch_output, AscendDeepEPLLOutput)
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
assert self.quant_method is not None
assert self.activation == "silu"
# NOTE: Ascend's Dispatch & Combine does not support FP16
output_dtype = torch.bfloat16
pertoken_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list_type = 1
seg_indptr = seg_indptr.to(torch.int64)
import torch_npu
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=seg_indptr,
output_dtype=output_dtype,
)[0]
return hidden_states
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
......
......@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
is_hip,
is_npu,
load_json_config,
)
_is_npu = is_npu()
try:
from deep_ep import Buffer, Config
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
if not _is_npu:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
use_deepep = True
except ImportError:
......@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
return DispatchOutputFormat.deepep_ll
class AscendDeepEPLLOutput(NamedTuple):
"""AscendDeepEP low latency dispatch output."""
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
topk_idx: torch.Tensor
topk_weights: torch.Tensor
masked_m: torch.Tensor
seg_indptr: torch.Tensor
expected_m: int
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
assert isinstance(DeepEPNormalOutput, DispatchOutput)
assert isinstance(DeepEPLLOutput, DispatchOutput)
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
class DeepEPDispatchMode(IntEnum):
......@@ -150,19 +175,20 @@ class DeepEPBuffer:
else:
raise NotImplementedError
total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
f"This may result in highly suboptimal performance. "
f"Consider using --deepep-config to change the behavior."
)
if not _is_npu:
total_num_sms = torch.cuda.get_device_properties(
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
f"This may result in highly suboptimal performance. "
f"Consider using --deepep-config to change the behavior."
)
cls._buffer = Buffer(
group,
......@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
masked_m
)
return DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
if _is_npu:
deepep_output = AscendDeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
self.handle[1],
expected_m,
)
else:
deepep_output = DeepEPLLOutput(
hidden_states,
topk_idx,
topk_weights,
masked_m,
expected_m,
)
return deepep_output
def _dispatch_core(
self,
......
......@@ -250,10 +250,11 @@ class TopK(CustomOp):
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias,
bias=self.correction_bias.to(torch.float32),
k_group=self.topk_group,
group_count=self.num_expert_group,
group_select_mode=1,
......
......@@ -3,7 +3,18 @@ from __future__ import annotations
import importlib
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
cast,
)
import torch
from torch.nn.parameter import Parameter
......@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
original_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(original_dtype)
x = (
torch_npu.npu_rms_norm(
x, self.weight.to(torch.float32), self.variance_epsilon
)[0]
+ self.bias
)
out, _, residual_out = torch_npu.npu_add_rms_norm(
residual, x, self.weight.data, self.variance_epsilon
)
out = out + self.bias
return out.to(x.dtype), residual_out
if residual is None:
return x.to(original_dtype)
return x.to(original_dtype), residual
out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
out = out + self.bias
return out.to(x.dtype)
return _rmsnorm_forward_oot
......@@ -571,8 +576,10 @@ class NPU_W8A8LinearMethodImpl:
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear
original_dtype = x.dtype
if original_dtype != torch.int8:
x = torch_npu.npu_quantize(
......@@ -583,8 +590,12 @@ class NPU_W8A8LinearMethodImpl:
-1,
True,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in Attention TP>1 case)
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias
return torch_npu.npu_quant_matmul(
x,
layer.weight,
......@@ -651,13 +662,21 @@ class NPU_W8A8LinearMethodMTImpl:
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
# To prevent import loops
from sglang.srt.layers.linear import RowParallelLinear
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
quant_bias = layer.quant_bias if tp_rank == 0 else None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in Attention TP>1 case)
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
quant_bias = None
else:
quant_bias = layer.quant_bias
return ops.quant_matmul(
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
)
......@@ -737,11 +756,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sglang.srt.layers.linear import RowParallelLinear
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
......@@ -780,7 +794,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
# use ATB quantize
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
return torch_npu.npu_quant_matmul(
quant_out,
......@@ -863,11 +876,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sglang.srt.layers.linear import RowParallelLinear
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
......
......@@ -680,7 +680,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
)
# Re-dispatch
if _is_hip or _is_npu:
if _is_hip:
self._forward_method = self.forward_native
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
......@@ -765,6 +765,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
key = key_rot
return query.to(dtype), key.to(dtype)
def forward_npu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
# and generalization to more scenarios will be supported in the future.
if query.shape[1] * query.shape[2] > 4096:
return self.forward_native(positions, query, key, offsets)
num_tokens = query.shape[0]
rotary_mode = "half" if self.is_neox_style else "interleave"
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
query_rot = query[..., : self.rotary_dim]
key_rot = key[..., : self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]
query_rot, key_rot = torch_npu.npu_mrope(
torch.add(positions, offsets) if offsets is not None else positions,
query_rot.reshape(num_tokens, -1),
key_rot.reshape(num_tokens, -1),
self.cos_sin_cache,
self.rotary_dim,
mrope_section=[0, 0, 0],
rotary_mode=rotary_mode,
)
query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def forward_cpu(
self,
positions: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment