Unverified Commit d27a6f70 authored by Even Zhou, committed by GitHub

[Feature] Add MLAProcess for DeepSeek MLA on NPU (#10130)

parent 0753ef83
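
This change gates a fused MLA preprocess (MLAPO) path for Ascend NPUs behind the `SGLANG_NPU_USE_MLAPO` environment variable, which the documentation and test updates below set before launch. Below is a minimal sketch (not part of the commit) of checking the gate programmatically; it assumes SGLang with this patch installed on an Ascend machine.

```python
# Minimal sketch: the MLAPO path is enabled only when SGLANG_NPU_USE_MLAPO is
# set *before* the module is first imported, because the flag is read at
# import time via get_bool_env_var.
import os

os.environ["SGLANG_NPU_USE_MLAPO"] = "1"

from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
    is_mla_preprocess_enabled,
)

# True only on an Ascend NPU build of SGLang with the flag set; on other
# devices this stays False and the regular MLA path is used.
print(is_mla_preprocess_enabled())
```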
@@ -118,7 +118,7 @@ git clone https://github.com/sgl-project/sglang.git
 cd sglang/docker
 # Build the docker image
-docker build -t sglang-npu:main -f Dockerfile.npu .
+docker build -t <image_name> -f Dockerfile.npu .
 alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-size=16g \
 --device=/dev/davinci0 --device=/dev/davinci1 --device=/dev/davinci2 --device=/dev/davinci3 \
@@ -132,7 +132,7 @@ alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-siz
 --volume /var/queue_schedule:/var/queue_schedule --volume ~/.cache/:/root/.cache/'
 drun --env "HF_TOKEN=<secret>" \
-    sglang-npu:main \
+    <image_name> \
     python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend --host 0.0.0.0 --port 30000
 ```
@@ -149,7 +149,7 @@ Prefill:
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<PREFILL_HOST_IP>:<PORT>"
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \
@@ -174,8 +174,9 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<PREFILL_HOST_IP>:<PORT>"
 export HCCL_BUFFSIZE=200
 export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=24
+export SGLANG_NPU_USE_MLAPO=1
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \
@@ -198,7 +199,7 @@ drun sglang-npu:main \
 Mini_LB:
 ```shell
-drun sglang-npu:main \
+drun <image_name> \
     python -m sglang.srt.disaggregation.launch_lb \
     --prefill http://<PREFILL_HOST_IP>:8000 \
     --decode http://<DECODE_HOST_IP>:8001 \
......
@@ -9,6 +9,7 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
@@ -401,7 +402,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
         torch_npu.npu_fused_infer_attention_score.out(
@@ -437,6 +438,10 @@ class AscendAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO already saves the kv_cache, so skip saving it again here.
+            save_kv_cache = False
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,
......
import torch
import torch.nn.functional as F

from sglang.srt.utils import get_bool_env_var, is_npu

_is_npu = is_npu()
_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
_NPU_FORMAT_NZ = 29


def is_mla_preprocess_enabled() -> bool:
    return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG


if is_mla_preprocess_enabled():
    import sgl_kernel_npu
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch.npu.set_compile_mode(jit_compile=False)


def round_up(val: int, align: int) -> int:
    if align == 0:
        return 0
    return -(val // -align) * align
def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Pad a 2-D ND matrix and rearrange it into the blocked NZ layout."""
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    r_pad = r - nd_mat.shape[0]
    c_pad = c - nd_mat.shape[1]
    # F.pad pads the last dimension first: (left, right) for columns, then rows.
    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
    nz_mat = torch.permute(
        torch.reshape(
            nd_mat,
            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
        ),
        [2, 0, 1, 3],
    )
    nz_mat = torch.reshape(
        nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
    )
    return nz_mat
def trans_rope_weight(weight, rope_dim):
    """Reorder the last `rope_dim` rows from interleaved order into two contiguous halves."""
    weight_1 = weight[..., -rope_dim::2, :].contiguous()
    weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()
    weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)
    return weight.contiguous()
class NPUFusedMLAPreprocess(torch.nn.Module):
    def __init__(
        self,
        fused_qkv_a_proj_with_mqa,
        q_a_layernorm,
        kv_a_layernorm,
        q_b_proj,
        w_kc,
        rotary_emb,
        layer_id,
        num_local_heads,
        qk_nope_head_dim,
        qk_rope_head_dim,
    ):
        super().__init__()
        self.qkv_a_proj = fused_qkv_a_proj_with_mqa
        self.q_a_layernorm = q_a_layernorm
        self.kv_a_layernorm = kv_a_layernorm
        self.q_b_proj = q_b_proj
        self.w_kc = w_kc.contiguous()
        self.rotary_emb = rotary_emb
        self.layer_id = layer_id
        self.has_preprocess_weights = False
        self.q_lora_rank = self.q_b_proj.input_size  # 1536
        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
        self.num_local_heads = num_local_heads  # tp
        self.qk_nope_head_dim = qk_nope_head_dim  # 128
        self.qk_rope_head_dim = qk_rope_head_dim  # 64
    def preprocess_weights(self, hidden_states):
        self.dummy = torch.empty(
            (hidden_states.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)
        self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)

        # matmul_0 weight [7168, 2112]
        fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[
            :, : self.q_lora_rank
        ].clone()  # [7168, 1536]
        fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[
            :, self.q_lora_rank :
        ].clone()  # [7168, 576]
        # rope fit
        fused_qkv_a_proj_with_mqa_weight_kv_t = (
            fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_weight_kv = (
            fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()
        )
        # cat nz
        fused_qkv_a_proj_with_mqa_weight_new = torch.cat(
            (fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),
            dim=-1,
        )
        fused_qkv_a_proj_with_mqa_weight = (
            fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_nz = (
            transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))
            .unsqueeze(0)
            .contiguous()
        )
        self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
            fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
        )

        # matmul_0 deq_scale [2112]
        fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
            : self.q_lora_rank
        ].clone()  # [1536]
        fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[
            self.q_lora_rank :
        ].clone()  # [576]
        # rope fit
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_deq_scale_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_deq_scale_kv,
                fused_qkv_a_proj_with_mqa_deq_scale_q,
            ),
            dim=-1,
        )

        # matmul_0 quant_bias [2112]
        fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[
            : self.q_lora_rank
        ].clone()  # [1536]
        fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[
            self.q_lora_rank :
        ].clone()  # [576]
        # rope fit
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_quant_bias_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_quant_bias_kv,
                fused_qkv_a_proj_with_mqa_quant_bias_q,
            ),
            dim=-1,
        )

        # matmul_1 weight [1536, num_head * 192]
        q_b_proj_weight = self.q_b_proj.weight.data.clone()
        q_b_proj_weight = q_b_proj_weight.t().reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)
        q_b_proj_weight = q_b_proj_weight.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1
        )
        q_b_proj_weight_nz = (
            transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
        )
        self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
            q_b_proj_weight_nz, _NPU_FORMAT_NZ
        )

        # matmul_1 deq_scale [num_head * 192]
        q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
        q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_deq_scale = trans_rope_weight(
            q_b_proj_deq_scale, self.qk_rope_head_dim
        )
        self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )

        # matmul_1 quant_bias [num_head * 192]
        q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()
        q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_quant_bias = trans_rope_weight(
            q_b_proj_quant_bias, self.qk_rope_head_dim
        )
        self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )
    def get_sin_cos(self, positions):
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2)
        sin = sin.repeat(1, 2)
        return cos, sin

    def get_kv_cache_and_cache_idx(self, forward_batch):
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)
        slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
        return k_cache, v_cache, slot_mapping
    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
        input_dtype = hidden_states.dtype
        if not self.has_preprocess_weights:
            # Lazily convert the quantized weights to the fused-kernel layout on first call.
            self.preprocess_weights(hidden_states)
            self.has_preprocess_weights = True
            self.dtype = hidden_states.dtype

        cos, sin = self.get_sin_cos(positions)
        k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)

        q_nope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )
        q_rope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )
        # TODO: dummy inputs to be removed
        # https://github.com/sgl-project/sgl-kernel-npu/issues/78
        torch.ops.npu.mla_preprocess(
            hidden_states,
            self.dummy,
            self.dummy,
            self.qkv_a_proj_weight_nz,
            self.qkv_a_proj_deq_scale_kvq,
            self.q_a_layernorm.weight,
            self.q_a_layernorm.bias,
            self.q_b_proj_weight_nz,
            self.q_b_proj_deq_scale,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            self.w_kc,
            k_cache,
            v_cache,
            slot_mapping,
            quant_scale0=self.qkv_a_proj.input_scale,
            quant_offset0=self.qkv_a_proj_input_offset,
            bias0=self.qkv_a_proj_quant_bias_kvq,
            quant_scale1=self.q_b_proj.input_scale,
            quant_offset1=self.q_b_proj_input_offset,
            bias1=self.q_b_proj_quant_bias,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            q_out0=q_nope_out,
            kv_cache_out0=k_cache,
            q_out1=q_rope_out,
            kv_cache_out1=v_cache,
        )
        return (
            q_rope_out,
            v_cache,
            q_nope_out,
            k_cache,
            forward_batch,
            zero_allocator,
            positions,
        )
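
For orientation, here is an illustrative sketch (not part of the commit) of what the weight-layout helpers above do. It assumes the functions from this new module are in scope, and uses a toy matrix whose sides are already multiples of the (16, 32) block size that preprocess_weights() passes to transdata(), so no padding is involved.

```python
import torch

# Toy 48x64 "weight": rows a multiple of 16, columns a multiple of 32.
w = torch.arange(48 * 64, dtype=torch.float32).reshape(48, 64)

# transdata() regroups the ND matrix into NZ blocks:
# (48, 64) -> (64 // 32, 48, 32) == (2, 48, 32)
nz = transdata(w, block_size=(16, 32))
print(nz.shape)  # torch.Size([2, 48, 32])

# trans_rope_weight() reorders the last `rope_dim` rows from interleaved
# (even, odd, even, odd) order into two contiguous halves (evens, then odds).
m = torch.arange(8, dtype=torch.float32).reshape(8, 1)
out = trans_rope_weight(m.clone(), rope_dim=4)
print(out[-4:, 0])  # tensor([4., 6., 5., 7.])
```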
@@ -782,27 +782,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-        # and generalization to more scenarios will be supported in the future.
-        if query.shape[1] * query.shape[2] > 4096:
-            return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
-        )
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
......
@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with W8A8Int8"
+            self.mla_preprocess = None
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
@@ -1263,9 +1273,28 @@
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
......
@@ -174,6 +174,8 @@ def is_blackwell():
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +183,8 @@ def is_sm100_supported(device=None) -> bool:
 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
......
@@ -60,6 +60,7 @@ class TestAscendDeepEP(CustomTestCase):
         cls.extra_envs = {
             "HCCL_BUFFSIZE": "500",
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
+            "SGLANG_NPU_USE_MLAPO": "1",
         }
         os.environ.update(cls.extra_envs)
......