Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
......@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"disable_chunked_prefix_cache"
]
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
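A minimal sketch of the padding rule this docstring describes, assuming hypothetical page_size and alignment constants (the real constraint values live elsewhere in the backend and are not shown in this hunk):
import math
def calc_padded_blocks(max_seq_len: int, page_size: int = 64,
                       trtllm_align_tokens: int = 128,
                       triton_pages_per_block: int = 64) -> int:
    # Hypothetical sketch: round the per-sequence page count up so it
    # satisfies both a TRT-LLM alignment (in tokens) and a Triton padding
    # granularity (in pages).
    blocks = math.ceil(max_seq_len / page_size)
    constraint = max(trtllm_align_tokens // page_size, triton_pages_per_block)
    return math.ceil(blocks / constraint) * constraint
# Example: 1000 tokens with 64-token pages -> 16 pages, padded up to 64.
assert calc_padded_blocks(1000) == 64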
......@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
......@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info,
)
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
# Custom fast-path for decode/idle.
# Capture at full width so that longer sequences remain safe during replay
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
......@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
......@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_cpu,
)
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
del seq_lens_sum  # seq_lens_sum does not account for num_draft_tokens, but it is not needed here
metadata = self.decode_cuda_graph_metadata[bs]
# Update block indices for new sequences.
......@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
cum_seq_lens_q,
seq_lens,
)
elif (
forward_batch.forward_mode.is_decode_or_idle()
or forward_batch.forward_mode.is_target_verify()
):
elif forward_batch.forward_mode.is_decode_or_idle():
bs = forward_batch.batch_size
# Get maximum sequence length.
......@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
else:
max_seq = forward_batch.seq_lens.max().item()
seq_lens = forward_batch.seq_lens
if forward_batch.forward_mode.is_target_verify():
max_seq = max_seq + self.num_draft_tokens
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
seq_lens,
seq_lens.device,
forward_batch.seq_lens,
forward_batch.seq_lens.device,
)
max_seq_len_val = int(max_seq)
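For context, a small numeric sketch of the target-verify adjustment above (values are illustrative, not from the source):
import torch
# Hypothetical batch: two requests with 10 and 37 cached tokens,
# verifying 4 speculative draft tokens each.
seq_lens = torch.tensor([10, 37])
num_draft_tokens = 4
# During target-verify the draft tokens occupy KV-cache slots too, so both
# the per-sequence lengths and the padding bound grow by num_draft_tokens
# before the block KV indices are built.
seq_lens_verify = seq_lens + num_draft_tokens          # tensor([14, 41])
max_seq = int(seq_lens_verify.max().item())            # 41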
......@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
else:
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# For the FP8 path, the query and rope parts are already merged by the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
......@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if forward_batch.forward_mode.is_draft_extend():
):
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
# Save KV cache if requested
if save_kv_cache:
assert (
k is not None and k_rope is not None
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
if q_rope is not None:
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
q = _concat_mla_absorb_q_general(q, q_rope)
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
if forward_batch.forward_mode.is_target_verify():
metadata = (
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
or self.forward_decode_metadata
)
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
bs = forward_batch.batch_size
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
seq_lens = (
forward_batch.seq_lens.to(torch.int32)
+ forward_batch.spec_info.draft_token_num
)
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
# TODO: consider using the `mla_rope_quantize_fp8` fusion
q = q.to(self.data_type)
assert kv_cache.dtype == self.data_type
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=metadata.block_kv_indices,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=bmm1_scale,
# The chunked prefix cache is not enabled; use the FlashInfer MLA prefill kernel
if forward_batch.attn_attend_prefix_cache is None:
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
# Reshape output directly without slicing
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
if forward_batch.attn_attend_prefix_cache:
# MHA over the chunked prefix KV cache when running a model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert q_rope is None
assert k_rope is None
chunk_idx = forward_batch.prefix_chunk_idx
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
if not forward_batch.attn_attend_prefix_cache:
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
seq_lens=self.forward_prefill_metadata.seq_lens,
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
max_kv_len=self.forward_prefill_metadata.max_seq_len,
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=-1.0,
o_sf_scale=1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
enable_pdl=False,
is_causal=False,
return_lse=True,
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
is_causal=True,
return_lse=forward_batch.mha_return_lse,
)
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=self.forward_prefill_metadata.seq_lens,
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=self.forward_prefill_metadata.max_seq_len,
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
enable_pdl=False,
is_causal=True,
return_lse=forward_batch.mha_return_lse,
)
else:
if not (
forward_batch.attn_attend_prefix_cache is not None
and forward_batch.mha_return_lse
):
output = super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
else:
# MHA over the chunked prefix KV cache when running a model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert q_rope is None
assert k_rope is None
chunk_idx = forward_batch.prefix_chunk_idx
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=-1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
enable_pdl=False,
is_causal=False,
return_lse=True,
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
)
return output
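A short sketch of the shape handling in the target-verify branch of forward_extend above, using illustrative sizes (bs=2, 4 draft tokens, 16 query heads, head_dim 576):
import torch
bs, num_draft_tokens, num_q_heads, head_dim = 2, 4, 16, 576
# Queries arrive flattened over (batch, draft positions); the decode kernel
# expects [bs, num_draft_tokens, num_q_heads, head_dim].
q = torch.randn(bs * num_draft_tokens, num_q_heads, head_dim)
q = q.view(bs, -1, num_q_heads, head_dim)
assert q.shape == (bs, num_draft_tokens, num_q_heads, head_dim)
# The KV lengths seen by the kernel include the draft tokens being verified.
seq_lens = torch.tensor([10, 37], dtype=torch.int32) + num_draft_tokens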
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
......@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
kv_indptr_buf=self.kv_indptr[i],
q_indptr_decode_buf=self.q_indptr_decode,
)
def _concat_mla_absorb_q_general(q_nope, q_rope):
if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
return concat_mla_absorb_q(q_nope, q_rope)
else:
return torch.cat([q_nope, q_rope], dim=-1)
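A usage sketch for the helper above: for DeepSeek-style MLA the no-PE part is 512-wide (kv_lora_rank) and the RoPE part 64-wide (qk_rope_head_dim), so the concatenation yields the 576-wide absorbed query; the fused concat_mla_absorb_q kernel is only taken on CUDA for exactly those widths, otherwise the plain torch.cat fallback runs:
import torch
num_tokens, num_heads = 8, 16
q_nope = torch.randn(num_tokens, num_heads, 512)   # kv_lora_rank part
q_rope = torch.randn(num_tokens, num_heads, 64)    # qk_rope_head_dim part
# Fallback path (no fused CUDA kernel): concatenate on the last dim.
query = torch.cat([q_nope, q_rope], dim=-1)
assert query.shape == (num_tokens, num_heads, 576)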
......@@ -16,19 +16,14 @@ from sglang.srt.utils import (
get_device_capability,
is_blackwell,
is_cuda,
is_npu,
print_info_once,
)
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_cuda:
from sgl_kernel.flash_attn import flash_attn_varlen_func
if _is_npu:
import torch_npu
from sglang.srt.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
......@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
return output
class VisionAscendAttention(nn.Module):
def __init__(
self,
**kwargs,
):
if not _is_npu:
raise Exception("VisionAscendAttention is only available on Ascend NPU")
super().__init__()
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
bsz: int,
seq_len: int,
**kwargs,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b]
Returns:
[b * s, h, head_size]
"""
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
if seq_lens.is_npu:
# seq_lens must be on CPU because of an operator restriction
seq_lens = seq_lens.to("cpu")
_, num_heads, head_size = q.shape
num_kv_heads = k.shape[1]
output = torch.empty_like(q)
# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
query=q,
key=k,
value=v,
seq_len=seq_lens.to(torch.int32),
scale_value=head_size**-0.5,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
out=output,
)
return output
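A minimal sketch of the cu_seqlens handling in the forward above, assuming equal-length sequences when cu_seqlens is not provided (the helper name and behavior are an assumption based on its usage):
import torch
def cu_seqlens_for_shape(bsz: int, seq_len: int) -> torch.Tensor:
    # Cumulative boundaries for bsz sequences of identical length:
    # [0, seq_len, 2*seq_len, ..., bsz*seq_len]
    return torch.arange(0, (bsz + 1) * seq_len, seq_len, dtype=torch.int32)
cu_seqlens = cu_seqlens_for_shape(bsz=3, seq_len=5)   # [0, 5, 10, 15]
# Per-sequence lengths are the adjacent differences of the cumulative array.
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
assert seq_lens.tolist() == [5, 5, 5]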
QKV_BACKEND_IMPL = {
"triton_attn": VisionTritonAttention,
"sdpa": VisionSdpaAttention,
"fa3": VisionFlash3Attention,
"ascend_attn": VisionAscendAttention,
}
......
......@@ -50,7 +50,6 @@ from sglang.srt.utils import (
is_hip,
is_sm90_supported,
is_sm100_supported,
prepare_weight_cache,
)
_is_flashinfer_available = is_flashinfer_available()
......@@ -276,11 +275,7 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
cache=None,
):
if cache is not None:
self._context.cache = cache
return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states,
residual=residual,
......@@ -354,7 +349,6 @@ class CommunicateContext:
attn_tp_size: int
attn_dp_size: int
tp_size: int
cache = None
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b]
......@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if context.cache is not None:
_ = prepare_weight_cache(hidden_states, context.cache)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual
......
......@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
assert (
x.shape == residual.shape and x.dtype == residual.dtype
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune:
......
......@@ -127,45 +127,21 @@ class RMSNorm(CustomOp):
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
# def forward_hip(
# self,
# x: torch.Tensor,
# residual: Optional[torch.Tensor] = None,
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if not x.is_contiguous():
# # NOTE: Remove this if aiter kernel supports discontinuous input
# x = x.contiguous()
# if residual is not None:
# if _vllm_version < Version("0.9"):
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
# return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
# out = torch.empty_like(x)
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
# return out
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
# NOTE: Remove this once the aiter kernel supports non-contiguous input
x = x.contiguous()
if residual is not None:
try:
output = torch.empty_like(x)
if _vllm_version < Version("0.9"):
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
else:
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
......@@ -175,21 +151,10 @@ class RMSNorm(CustomOp):
self.variance_epsilon,
)
return output, residual_out
except TypeError:
fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
def forward_native(
self,
x: torch.Tensor,
......
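For reference alongside the forward_hip fallback above, a plain-PyTorch sketch of what rms_norm and fused_add_rms_norm compute (semantics only, assuming the usual residual-then-normalize convention; the real kernels are fused):
import torch
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
def fused_add_rms_norm_ref(x, residual, weight, eps):
    # Add the residual first, then RMS-normalize the sum; the updated
    # residual is returned alongside the normalized output.
    residual_out = x + residual
    return rms_norm_ref(residual_out, weight, eps), residual_out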
......@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
_ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
if TYPE_CHECKING:
......@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
# Pad for special cases like qwen2_5_VL's MLP, whose shapes are not 8-aligned
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[output_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, output_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)
# Special case for AQLM codebooks.
elif is_metadata:
......@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
shard_size,
)
else:
# Pad for special cases like qwen2_5_VL's MLP, whose shapes are not 8-aligned
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[input_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, input_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
input_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
......
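The pad_or_narrow_weight helper referenced by the removed branches above is not shown in this diff; the following is a hedged sketch of what such a helper plausibly does for shards that run past the end of the checkpoint tensor (an assumption, not the actual implementation):
import torch
import torch.nn.functional as F
def pad_or_narrow_weight_sketch(loaded_weight: torch.Tensor, dim: int,
                                start_idx: int, shard_size: int) -> torch.Tensor:
    # Take shard_size elements along dim starting at start_idx; if the tensor
    # is too short (e.g. a qwen2_5_VL MLP whose size is not 8-aligned), keep
    # what exists and zero-pad the tail up to shard_size.
    available = loaded_weight.shape[dim] - start_idx
    shard = loaded_weight.narrow(dim, start_idx, min(shard_size, available))
    pad_len = shard_size - shard.shape[dim]
    if pad_len > 0:
        # F.pad takes pairs from the last dim backwards; pad only `dim`.
        pad_spec = [0, 0] * (loaded_weight.dim() - 1 - dim) + [0, pad_len]
        shard = F.pad(shard, pad_spec)
    return shard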
......@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
......@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
if self.use_fp32_lm_head:
logits = torch.matmul(
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
)
elif use_intel_amx_backend(lm_head):
if use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
......@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
if self.use_fp32_lm_head:
with torch.cuda.amp.autocast(enabled=False):
logits = lm_head.quant_method.apply(
lm_head, hidden_states.to(torch.float32), embedding_bias
)
else:
logits = lm_head.quant_method.apply(
lm_head, hidden_states, embedding_bias
)
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)
......
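A condensed sketch of the enable_fp32_lm_head path removed above (non-GGUF branch): the final projection is performed in float32, presumably to reduce precision loss in the logits, and scaled afterwards. Shapes are illustrative:
import torch
hidden_states = torch.randn(4, 1024, dtype=torch.bfloat16)
lm_head_weight = torch.randn(32000, 1024, dtype=torch.bfloat16)
logit_scale = 0.5  # may be None in practice
# FP32 LM head: upcast both operands before the matmul.
logits = torch.matmul(hidden_states.to(torch.float32),
                      lm_head_weight.to(torch.float32).T)
if logit_scale is not None:
    logits.mul_(logit_scale)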
......@@ -789,45 +789,69 @@ class DeepEPMoE(EPMoE):
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
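For orientation, a plain-PyTorch reference of what the gmm1 -> swiglu -> gmm2 sequence above computes for a single expert, ignoring quantization and the grouped dispatch (the gate/up split follows the common swiglu convention; the NPU kernels may order it differently):
import torch
import torch.nn.functional as F
hidden, inter = 1024, 2048
x = torch.randn(8, hidden)                 # tokens routed to one expert
w13 = torch.randn(2 * inter, hidden)       # fused gate_up_proj weight
w2 = torch.randn(hidden, inter)            # down_proj weight
gate_up = x @ w13.t()                      # gmm1: gate_up projection
gate, up = gate_up.chunk(2, dim=-1)
h = F.silu(gate) * up                      # act_fn: swiglu
out = h @ w2.t()                           # gmm2: down projection
assert out.shape == (8, hidden)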
......@@ -836,47 +860,72 @@ class DeepEPMoE(EPMoE):
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
}
}
......@@ -51,14 +51,10 @@ def get_moe_configs(
# We found that reusing the fused_moe_kernel config tuned for Triton 3.1.0 under Triton 3.2.0 degrades performance,
# so we also include the Triton version as a key when looking up the fused_moe_kernel config to get the best performance.
config_dir = os.environ.get(
"SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
)
triton_version = triton.__version__
version_dir = f"triton_{triton_version.replace('.', '_')}"
config_file_path = os.path.join(
config_dir,
os.path.dirname(os.path.realpath(__file__)),
"configs",
version_dir,
json_file_name,
......@@ -79,7 +75,7 @@ def get_moe_configs(
if try_triton_version == triton_version:
continue
try_config_file_path = os.path.join(
config_dir,
os.path.dirname(os.path.realpath(__file__)),
"configs",
f"triton_{try_triton_version.replace('.', '_')}",
json_file_name,
......
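A small sketch of the lookup path changed above: the config directory defaults to the directory containing the module, the removed branch let SGLANG_MOE_CONFIG_DIR override it, and the Triton version is folded into the path (the JSON file-name format is not shown in this hunk):
import os
import triton
def moe_config_path(json_file_name: str) -> str:
    # Behaviour of the removed branch: allow an environment override,
    # otherwise use the directory next to this module.
    config_dir = os.environ.get(
        "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
    )
    version_dir = f"triton_{triton.__version__.replace('.', '_')}"
    return os.path.join(config_dir, "configs", version_dir, json_file_name)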
......@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if (
should_use_flashinfer_trtllm_moe()
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
):
if should_use_flashinfer_trtllm_moe():
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
......
......@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu
__all__ = [
......@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
else:
if not use_presharded_weights:
# Pad for special cases like qwen2_5_VL's MLP, whose shapes are not 8-aligned
start_idx = tp_rank * shard_size
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[self.output_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, self.output_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
self.output_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
return
else:
# Pad for special cases like qwen2_5_VL's MLP, whose shapes are not 8-aligned
start_idx = tp_rank * shard_size
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[self.input_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, self.input_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
self.input_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
......
......@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
......
......@@ -2,12 +2,10 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsW8A8Fp8",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW8A8Int8",
]
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per channel
if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
if self.input_symmetric:
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
else:
input_scale = layer.input_scale
input_zero_point = layer.input_zero_point
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
layer.input_scale = Parameter(scale, requires_grad=False)
layer.input_zero_point = Parameter(azp, requires_grad=False)
else:
layer.input_scale = None
layer.input_zero_point = None
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
weight = layer.weight
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = layer.input_zero_point * azp_adj
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
else:
layer.azp_adj = None
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = PerTensorScaleParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
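A dequantized reference for what apply_weights computes with per-token activation scales and per-channel weight scales; this is a sketch of the math only, not the int8_scaled_mm kernel:
import torch
def per_token_quant_int8_ref(x: torch.Tensor):
    # Symmetric per-row (per-token) int8 quantization.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale
def int8_scaled_mm_ref(x_q, w_q_t, x_scale, w_scale, out_dtype, bias=None):
    # w_q_t is the transposed int8 weight (K x N), as stored after
    # process_weights_after_loading; w_scale has one entry per output channel.
    out = (x_q.float() @ w_q_t.float()) * x_scale * w_scale.reshape(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
x = torch.randn(4, 64)
w = torch.randn(32, 64)                        # [out_features, in_features]
w_q, w_scale = per_token_quant_int8_ref(w)     # per-output-channel scales
x_q, x_scale = per_token_quant_int8_ref(x)
y = int8_scaled_mm_ref(x_q, w_q.t(), x_scale, w_scale.reshape(-1), torch.float32)
assert y.shape == (4, 32)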
import logging
import torch
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
logger = logging.getLogger(__name__)
......@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
try:
import deep_gemm
except ImportError:
logger.warning("Failed to import deep_gemm; disabling ENABLE_JIT_DEEPGEMM.")
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
......
......@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
topk_weights = topk_weights.to(
torch.float32
) # aiter's moe_sorting requires topk_weights to be FP32
if hasattr(torch, "float4_e2m1fn_x2"):
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
output = fused_moe(
x,
w13_weight,
w2_weight,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
......
......@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
moe_runner_config = self.moe_runner_config
topk_weights, topk_ids, _ = topk_output
if hasattr(torch, "float4_e2m1fn_x2"):
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
output = fused_moe(
x,
w13_weight,
w2_weight,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
......
......@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import is_npu, set_weight_attrs
_is_npu = is_npu()
if not _is_npu:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
......