change / sglang · Commits · 852a49c5

Commit 852a49c5, authored Sep 30, 2025 by maxiao
Parent: 8f7453e3 · Changes: 159

    adapt to dsv32 on dcu

Showing 20 changed files with 227 additions and 883 deletions (+227 / -883)
Changed files:

  python/sglang/srt/layers/attention/trtllm_mla_backend.py                                            +70  -142
  python/sglang/srt/layers/attention/vision.py                                                         +0   -58
  python/sglang/srt/layers/communicator.py                                                             +0    -8
  python/sglang/srt/layers/elementwise.py                                                              +1    -3
  python/sglang/srt/layers/layernorm.py                                                               +10   -45
  python/sglang/srt/layers/linear.py                                                                   +4   -21
  python/sglang/srt/layers/logits_processor.py                                                         +2   -15
  python/sglang/srt/layers/moe/ep_moe/layer.py                                                       +120   -71
  python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json   +0  -146
  python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json   +0  -146
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py                             +2    -6
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py                                               +1    -4
  python/sglang/srt/layers/parameter.py                                                                +6   -23
  python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py                       +0    -1
  python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py                         +0    -2
  python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py     +0  -173
  python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py                                +3    -0
  python/sglang/srt/layers/quantization/mxfp4.py                                                       +2   -10
  python/sglang/srt/layers/quantization/quark/quark_moe.py                                             +2    -9
  python/sglang/srt/layers/quantization/w4afp8.py                                                      +4    -0
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]
-        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.

@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         """Initialize metadata for CUDA graph capture."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,

@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
-        if forward_mode.is_target_verify():
-            seq_lens = seq_lens + self.num_draft_tokens
-
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)

@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,

@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )
-        if forward_mode.is_target_verify():
-            seq_lens = seq_lens + self.num_draft_tokens
-            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
-
         metadata = self.decode_cuda_graph_metadata[bs]
         # Update block indices for new sequences.

@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif (
-            forward_batch.forward_mode.is_decode_or_idle()
-            or forward_batch.forward_mode.is_target_verify()
-        ):
+        elif forward_batch.forward_mode.is_decode_or_idle():
             bs = forward_batch.batch_size
             # Get maximum sequence length.

@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()
-            seq_lens = forward_batch.seq_lens
-            if forward_batch.forward_mode.is_target_verify():
-                max_seq = max_seq + self.num_draft_tokens
-                seq_lens = seq_lens + self.num_draft_tokens
-
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                seq_lens,
-                seq_lens.device,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
             )
             max_seq_len_val = int(max_seq)

@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+            if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
+                query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
+            else:
+                query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if forward_batch.forward_mode.is_draft_extend():
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
 
         # Save KV cache if requested
         if save_kv_cache:
             assert (
                 k is not None and k_rope is not None
             ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, k_rope
             )
 
-        if q_rope is not None:
-            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-            q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
-            q = _concat_mla_absorb_q_general(q, q_rope)
-        else:
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-        if k_rope is not None:
-            k = torch.cat([k, k_rope], dim=-1)
-        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-
-        if forward_batch.forward_mode.is_target_verify():
-            metadata = (
-                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-                or self.forward_decode_metadata
-            )
-            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
-            bs = forward_batch.batch_size
-            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
-            q_scale = 1.0
-            k_scale = (
-                layer.k_scale_float
-                if getattr(layer, "k_scale_float", None) is not None
-                else 1.0
-            )
-            bmm1_scale = q_scale * k_scale * layer.scaling
-            seq_lens = (
-                forward_batch.seq_lens.to(torch.int32)
-                + forward_batch.spec_info.draft_token_num
-            )
-            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
-            # TODO may use `mla_rope_quantize_fp8` fusion
-            q = q.to(self.data_type)
-            assert kv_cache.dtype == self.data_type
-            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
-                query=q, kv_cache=kv_cache, workspace_buffer=self.workspace_buffer,
-                qk_nope_head_dim=self.qk_nope_head_dim, kv_lora_rank=self.kv_lora_rank,
-                qk_rope_head_dim=self.qk_rope_head_dim,
-                block_tables=metadata.block_kv_indices,
-                seq_lens=seq_lens, max_seq_len=max_seq_len, bmm1_scale=bmm1_scale,
-            )
-            # Reshape output directly without slicing
-            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-            return output
-
-        if forward_batch.attn_attend_prefix_cache:
-            # MHA for chunked prefix kv cache when running model with MLA
-            assert forward_batch.prefix_chunk_idx is not None
-            assert forward_batch.prefix_chunk_cu_seq_lens is not None
-            assert q_rope is None
-            assert k_rope is None
-            chunk_idx = forward_batch.prefix_chunk_idx
-            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
-            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
-                query=q, key=k, value=v, workspace_buffer=self.workspace_buffer,
-                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
-                max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
-                bmm1_scale=layer.scaling, bmm2_scale=1.0, o_sf_scale=-1.0,
-                batch_size=forward_batch.batch_size, window_left=-1,
-                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
-                enable_pdl=False, is_causal=False, return_lse=True,
-                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
-            )
-
-        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
-            query=q, key=k, value=v, workspace_buffer=self.workspace_buffer,
-            seq_lens=self.forward_prefill_metadata.seq_lens,
-            max_q_len=self.forward_prefill_metadata.max_seq_len,
-            max_kv_len=self.forward_prefill_metadata.max_seq_len,
-            bmm1_scale=layer.scaling, bmm2_scale=1.0, o_sf_scale=1.0,
-            batch_size=forward_batch.batch_size, window_left=-1,
-            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
-            enable_pdl=False, is_causal=True,
-            return_lse=forward_batch.mha_return_lse,
-        )
+        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+        if forward_batch.attn_attend_prefix_cache is None:
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q, key=k, value=v, workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling, bmm2_scale=1.0, o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size, window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False, is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            if not (
+                forward_batch.attn_attend_prefix_cache is not None
+                and forward_batch.mha_return_lse
+            ):
+                output = super().forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                )
+            else:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert q_rope is None
+                assert k_rope is None
+                chunk_idx = forward_batch.prefix_chunk_idx
+                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                    query=q, key=k, value=v, workspace_buffer=self.workspace_buffer,
+                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                    max_q_len=self.forward_prefill_metadata.max_seq_len,
+                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    bmm1_scale=layer.scaling, bmm2_scale=1.0, o_sf_scale=-1.0,
+                    batch_size=forward_batch.batch_size, window_left=-1,
+                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    enable_pdl=False, is_causal=False, return_lse=True,
+                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+                )
+        return output

@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
                 kv_indptr_buf=self.kv_indptr[i],
                 q_indptr_decode_buf=self.q_indptr_decode,
             )
-
-
-def _concat_mla_absorb_q_general(q_nope, q_rope):
-    if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
-        return concat_mla_absorb_q(q_nope, q_rope)
-    else:
-        return torch.cat([q_nope, q_rope], dim=-1)
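For context, the removed _concat_mla_absorb_q_general helper follows a common dispatch pattern: use a fused kernel when the tensor shapes match what the kernel was written for, otherwise fall back to a plain concatenation. Below is a minimal sketch of that pattern in plain PyTorch, not part of the commit; the fused kernel is passed in as a stand-in, since concat_mla_absorb_q comes from sgl_kernel and is only assumed to exist on CUDA builds.

    import torch

    def concat_absorb_q(q_nope: torch.Tensor, q_rope: torch.Tensor, fused_kernel=None) -> torch.Tensor:
        # Use the fused kernel only for the shapes it supports
        # (512-dim latent + 64-dim rope in the MLA case); otherwise torch.cat.
        if fused_kernel is not None and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
            return fused_kernel(q_nope, q_rope)
        return torch.cat([q_nope, q_rope], dim=-1)

    q_nope = torch.randn(4, 16, 512)
    q_rope = torch.randn(4, 16, 64)
    assert concat_absorb_q(q_nope, q_rope).shape == (4, 16, 576)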
python/sglang/srt/layers/attention/vision.py

@@ -16,19 +16,14 @@ from sglang.srt.utils import (
     get_device_capability,
     is_blackwell,
     is_cuda,
-    is_npu,
     print_info_once,
 )
 
 _is_cuda = is_cuda()
-_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
-if _is_npu:
-    import torch_npu
 
 from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,

@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
         return output
 
 
-class VisionAscendAttention(nn.Module):
-
-    def __init__(
-        self,
-        **kwargs,
-    ):
-        if not _is_npu:
-            raise Exception("VisionAscendAttention is only available for ascend npu")
-        super().__init__()
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
-        bsz: int,
-        seq_len: int,
-        **kwargs,
-    ) -> torch.Tensor:
-        r"""
-        Args:
-            cu_seqlens: [b]
-        Returns:
-             [b * s, h, head_size]
-        """
-        if cu_seqlens is None:
-            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-
-        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-        if seq_lens.is_npu:
-            # cu_seqlens must be on cpu because of operator restriction
-            seq_lens = seq_lens.to("cpu")
-        _, num_heads, head_size = q.shape
-        num_kv_heads = k.shape[1]
-        output = torch.empty_like(q)
-
-        # operator requires pta version >= 2.5.1
-        torch_npu._npu_flash_attention_unpad(
-            query=q,
-            key=k,
-            value=v,
-            seq_len=seq_lens.to(torch.int32),
-            scale_value=head_size**-0.5,
-            num_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            out=output,
-        )
-        return output
-
-
 QKV_BACKEND_IMPL = {
     "triton_attn": VisionTritonAttention,
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
-    "ascend_attn": VisionAscendAttention,
 }
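For orientation, QKV_BACKEND_IMPL is a plain name-to-class registry, which is why dropping a backend only needs its dictionary entry and class removed. A small illustrative sketch of that registry pattern, with stand-in classes rather than the real sglang vision backends:

    import torch
    from torch import nn

    class SdpaAttention(nn.Module):
        def forward(self, q, k, v):
            # Built-in scaled dot-product attention as a simple reference backend.
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    QKV_BACKEND_IMPL = {"sdpa": SdpaAttention}

    def build_backend(name: str) -> nn.Module:
        try:
            return QKV_BACKEND_IMPL[name]()
        except KeyError:
            raise ValueError(f"Unsupported vision attention backend: {name}")

    attn = build_backend("sdpa")
    q = k = v = torch.randn(1, 8, 16, 64)
    out = attn(q, k, v)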
python/sglang/srt/layers/communicator.py

@@ -50,7 +50,6 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
-    prepare_weight_cache,
 )
 
 _is_flashinfer_available = is_flashinfer_available()

@@ -276,11 +275,7 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        cache=None,
     ):
-        if cache is not None:
-            self._context.cache = cache
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,

@@ -354,7 +349,6 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
-    cache = None
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            if context.cache is not None:
-                _ = prepare_weight_cache(hidden_states, context.cache)
             hidden_states, residual = layernorm(hidden_states, residual)
 
         return hidden_states, residual
python/sglang/srt/layers/elementwise.py

@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
 def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert len(x.shape) == 2
-    assert (
-        x.shape == residual.shape and x.dtype == residual.dtype
-    ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
+    assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
     if autotune:
python/sglang/srt/layers/layernorm.py

@@ -127,45 +127,21 @@ class RMSNorm(CustomOp):
             return output, residual_out
         return rms_norm(x, self.weight.data, self.variance_epsilon)
 
-    # def forward_hip(
-    #     self,
-    #     x: torch.Tensor,
-    #     residual: Optional[torch.Tensor] = None,
-    # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    #     if not x.is_contiguous():
-    #         # NOTE: Remove this if aiter kernel supports discontinuous input
-    #         x = x.contiguous()
-    #     if residual is not None:
-    #         if _vllm_version < Version("0.9"):
-    #             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-    #             return x, residual
-    #         else:
-    #             residual_out = torch.empty_like(x)
-    #             output = torch.empty_like(x)
-    #             fused_add_rms_norm(
-    #                 output,
-    #                 x,
-    #                 residual_out,
-    #                 residual,
-    #                 self.weight.data,
-    #                 self.variance_epsilon,
-    #             )
-    #             return output, residual_out
-    #     out = torch.empty_like(x)
-    #     rms_norm(out, x, self.weight.data, self.variance_epsilon)
-    #     return out
-
     def forward_hip(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-            if _vllm_version < Version("0.9"):
-                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
-                return x, residual
-            else:
+            try:
+                output = torch.empty_like(x)
                 residual_out = torch.empty_like(x)
-                output = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,

@@ -175,21 +151,10 @@ class RMSNorm(CustomOp):
                     residual_out,
                     residual,
                     self.weight.data,
                     self.variance_epsilon,
                 )
                 return output, residual_out
+            except TypeError:
+                fused_add_rms_norm(
+                    x,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return x, residual
 
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
 
     def forward_native(
         self,
         x: torch.Tensor,
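The try/except TypeError structure in forward_hip is a signature-compatibility shim: call the newer out-of-place fused_add_rms_norm form first and fall back to the older in-place form if the kernel rejects the argument list. A self-contained sketch of that pattern with a dummy kernel (the real aiter/sgl kernels are assumed, not imported here):

    import torch

    def fused_add_rms_norm_legacy(x, residual, weight, eps):
        # Stand-in for the "old" in-place signature: mutates x and returns (x, residual).
        x.add_(residual)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + eps)).mul_(weight)
        return x, residual

    def add_rms_norm(x, residual, weight, eps):
        try:
            # Newer kernels take explicit output buffers: (output, x, residual_out, residual, w, eps).
            # Calling the legacy stand-in with this arity raises TypeError, triggering the fallback.
            output, residual_out = torch.empty_like(x), torch.empty_like(x)
            fused_add_rms_norm_legacy(output, x, residual_out, residual, weight, eps)
            return output, residual_out
        except TypeError:
            # Older kernels work in place with four arguments.
            return fused_add_rms_norm_legacy(x, residual, weight, eps)

    out, res = add_rms_norm(torch.randn(2, 8), torch.randn(2, 8), torch.ones(8), 1e-6)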
python/sglang/srt/layers/linear.py

@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:

@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
             if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
-                end_idx = start_idx + shard_size
-                if end_idx > loaded_weight.shape[output_dim]:
-                    loaded_weight = pad_or_narrow_weight(
-                        loaded_weight, output_dim, start_idx, shard_size
-                    )
-                else:
-                    loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for AQLM codebooks.
         elif is_metadata:

@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
                     shard_size,
                 )
             else:
-                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
-                end_idx = start_idx + shard_size
-                if end_idx > loaded_weight.shape[input_dim]:
-                    loaded_weight = pad_or_narrow_weight(
-                        loaded_weight, input_dim, start_idx, shard_size
-                    )
-                else:
-                    loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
python/sglang/srt/layers/logits_processor.py

@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.logit_scale = logit_scale
         self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
-        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (

@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
-            if self.use_fp32_lm_head:
-                logits = torch.matmul(
-                    hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
-                )
-            elif use_intel_amx_backend(lm_head):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,

@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             # TODO: use weight_packed_linear for GGUF models
-            if self.use_fp32_lm_head:
-                with torch.cuda.amp.autocast(enabled=False):
-                    logits = lm_head.quant_method.apply(
-                        lm_head, hidden_states.to(torch.float32), embedding_bias
-                    )
-            else:
-                logits = lm_head.quant_method.apply(
-                    lm_head, hidden_states, embedding_bias
-                )
+            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
 
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)
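The removed use_fp32_lm_head branch simply ran the final projection in float32 before any quantized or AMX path. As a reference, computing LM-head logits in fp32 from bf16 activations looks like the snippet below; the shapes and tensor names are illustrative, not the LogitsProcessor API.

    import torch

    hidden_states = torch.randn(4, 1024, dtype=torch.bfloat16)        # [tokens, hidden]
    lm_head_weight = torch.randn(32000, 1024, dtype=torch.bfloat16)   # [vocab, hidden]

    # fp32 LM head: upcast both operands, then matmul against the transposed weight.
    logits = torch.matmul(hidden_states.to(torch.float32), lm_head_weight.to(torch.float32).T)
    assert logits.shape == (4, 32000) and logits.dtype == torch.float32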
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -789,45 +789,69 @@ class DeepEPMoE(EPMoE):
         if isinstance(hidden_states, tuple):
             per_token_scale = hidden_states[1]
             hidden_states = hidden_states[0]
         else:
             # dynamic quant
             hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(hidden_states)
 
         group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
             hidden_states.device
         )
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states], weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)], per_token_scale=[per_token_scale],
-            split_item=2, group_list_type=group_list_type, group_type=0,
-            group_list=group_list, output_dtype=output_dtype,
-        )[0]
-        # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states], weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)], per_token_scale=[swiglu_out_scale],
-            split_item=2, group_list_type=group_list_type, group_type=0,
-            group_list=group_list, output_dtype=output_dtype,
-        )[0]
+        if self.w13_weight.dtype != torch.int8:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w13_weight.permute(0, 2, 1)],
+                # per_token_scale=[per_token_scale],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w2_weight.permute(0, 2, 1)],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
+        else:
+            if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
+                hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(hidden_states)
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w13_weight],
+                scale=[self.w13_weight_scale.to(output_dtype)], per_token_scale=[per_token_scale],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
+            # act_fn: swiglu
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)], per_token_scale=[swiglu_out_scale],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
         return hidden_states

@@ -836,47 +860,72 @@ class DeepEPMoE(EPMoE):
         assert isinstance(dispatch_output, DeepEPLLOutput)
         hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
-        per_token_scale = hidden_states[1]
-        hidden_states = hidden_states[0]
+        if isinstance(hidden_states, tuple):
+            per_token_scale = hidden_states[1]
+            hidden_states = hidden_states[0]
 
         group_list = group_list.to(torch.int64)
-        # gmm1: gate_up_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states], weight=[self.w13_weight],
-            split_item=2, group_list_type=group_list_type, group_type=0,
-            group_list=group_list, output_dtype=torch.int32,
-        )[0]
-        # act_fn: swiglu
-        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=hidden_states, weight_scale=self.w13_weight_scale.to(torch.float32),
-            activation_scale=per_token_scale, bias=None, quant_scale=None,
-            quant_offset=None, group_index=group_list, activate_left=True, quant_mode=1,
-        )
-        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states], weight=[self.w2_weight],
-            scale=[self.w2_weight_scale.to(output_dtype)], per_token_scale=[swiglu_out_scale],
-            split_item=2, group_list_type=group_list_type, group_type=0,
-            group_list=group_list, output_dtype=output_dtype,
-        )[0]
+        if self.w13_weight.dtype != torch.int8:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w13_weight.permute(0, 2, 1)],
+                # per_token_scale=[per_token_scale],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
+            hidden_states = torch_npu.npu_swiglu(hidden_states)
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w2_weight.permute(0, 2, 1)],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
+        else:
+            # gmm1: gate_up_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w13_weight],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=torch.int32,
+            )[0]
+            # act_fn: swiglu
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states, weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=per_token_scale, bias=None, quant_scale=None,
+                quant_offset=None, group_index=group_list, activate_left=True, quant_mode=1,
+            )
+            # gmm2: down_proj
+            hidden_states = torch_npu.npu_grouped_matmul(
+                x=[hidden_states], weight=[self.w2_weight],
+                scale=[self.w2_weight_scale.to(output_dtype)], per_token_scale=[swiglu_out_scale],
+                split_item=2, group_list_type=group_list_type, group_type=0,
+                group_list=group_list, output_dtype=output_dtype,
+            )[0]
         return hidden_states
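Both branches above feed npu_grouped_matmul either raw bf16 activations or int8 activations plus a per-token scale produced by npu_dynamic_quant. For readers without torch_npu, here is a rough pure-PyTorch sketch of what symmetric per-token dynamic int8 quantization computes; it is a conceptual stand-in, not the NPU kernel.

    import torch

    def dynamic_quant_per_token(x: torch.Tensor):
        # One scale per row (token): scale = max(|x|) / 127, values clamped to int8 range.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
        return x_q, scale.squeeze(-1)

    x = torch.randn(8, 4096)
    x_q, per_token_scale = dynamic_quant_per_token(x)
    x_deq = x_q.float() * per_token_scale.unsqueeze(-1)  # approximate reconstruction of x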
python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json  (deleted, 100644 → 0)

{
  "1":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 4},
  "2":    {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "4":    {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "8":    {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "16":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "24":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "32":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "48":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "64":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "96":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "128":  {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "256":  {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "512":  {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}
}
python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json  (deleted, 100644 → 0)

{
  "1":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "2":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "4":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 4},
  "8":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "16":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "24":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "32":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "48":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "64":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "96":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "128":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2},
  "256":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2},
  "512":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4}
}
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py

@@ -51,14 +51,10 @@ def get_moe_configs(
     # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
     # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
-    config_dir = os.environ.get(
-        "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
-    )
     triton_version = triton.__version__
     version_dir = f"triton_{triton_version.replace('.', '_')}"
     config_file_path = os.path.join(
-        config_dir,
+        os.path.dirname(os.path.realpath(__file__)),
         "configs",
         version_dir,
         json_file_name,

@@ -79,7 +75,7 @@ def get_moe_configs(
         if try_triton_version == triton_version:
             continue
         try_config_file_path = os.path.join(
-            config_dir,
+            os.path.dirname(os.path.realpath(__file__)),
            "configs",
            f"triton_{try_triton_version.replace('.', '_')}",
            json_file_name,
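The lookup above resolves a JSON file named after the kernel shape and device, placed in a per-Triton-version directory next to the module; the removed SGLANG_MOE_CONFIG_DIR override let users point it elsewhere. A small sketch of that path construction (the base directory and filename below are examples, not values from the repo):

    import os

    def moe_config_path(base_dir: str, triton_version: str, json_file_name: str) -> str:
        # e.g. triton 3.4.0 -> "triton_3_4_0"
        version_dir = f"triton_{triton_version.replace('.', '_')}"
        return os.path.join(base_dir, "configs", version_dir, json_file_name)

    path = moe_config_path(
        "/path/to/fused_moe_triton",  # os.path.dirname(os.path.realpath(__file__)) in the real code
        "3.4.0",
        "E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json",
    )
    # -> /path/to/fused_moe_triton/configs/triton_3_4_0/E=256,...json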
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if (
-            should_use_flashinfer_trtllm_moe()
-            and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
-        ):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
python/sglang/srt/layers/parameter.py

@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
-from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu
 
 __all__ = [

@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
-                start_idx = tp_rank * shard_size
-                end_idx = start_idx + shard_size
-                if end_idx > loaded_weight.shape[self.output_dim]:
-                    loaded_weight = pad_or_narrow_weight(
-                        loaded_weight, self.output_dim, start_idx, shard_size
-                    )
-                else:
-                    loaded_weight = loaded_weight.narrow(
-                        self.output_dim, start_idx, shard_size
-                    )
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
             return
         else:
-            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
-            start_idx = tp_rank * shard_size
-            end_idx = start_idx + shard_size
-            if end_idx > loaded_weight.shape[self.input_dim]:
-                loaded_weight = pad_or_narrow_weight(
-                    loaded_weight, self.input_dim, start_idx, shard_size
-                )
-            else:
-                loaded_weight = loaded_weight.narrow(self.input_dim, start_idx, shard_size)
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
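After this change both parameter classes load a tensor-parallel shard with a plain narrow along the partitioned dimension, starting at tp_rank * shard_size. A standalone sketch of that indexing:

    import torch

    def load_tp_shard(full_weight: torch.Tensor, dim: int, tp_rank: int, shard_size: int) -> torch.Tensor:
        # Equivalent to loaded_weight.narrow(dim, tp_rank * shard_size, shard_size)
        return full_weight.narrow(dim, tp_rank * shard_size, shard_size)

    w = torch.arange(16.0).reshape(8, 2)
    shard = load_tp_shard(w, dim=0, tp_rank=1, shard_size=4)  # rows 4..7
    assert shard.shape == (4, 2)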
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -2,12 +2,10 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
-from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
-    "CompressedTensorsW8A8Int8",
 ]
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py  (deleted, 100644 → 0)

# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter

from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
    def __init__(self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per channel
        if self.strategy == QuantizationStrategy.TENSOR:
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        # If channelwise, scales are already lined up, so just transpose.
        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight = layer.weight
            weight_scale = layer.weight_scale.data
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # INPUT SCALE
        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
            if self.input_symmetric:
                layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
            else:
                input_scale = layer.input_scale
                input_zero_point = layer.input_zero_point

                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)

                # AZP loaded as int8 but used as int32
                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)

                layer.input_scale = Parameter(scale, requires_grad=False)
                layer.input_zero_point = Parameter(azp, requires_grad=False)
        else:
            layer.input_scale = None
            layer.input_zero_point = None

        # azp_adj is the AZP adjustment term, used to account for weights.
        # It does not depend on scales or azp, so it is the same for
        # static and dynamic quantization.
        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
        if not self.input_symmetric:
            weight = layer.weight
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
            if self.is_static_input_scheme:
                # cutlass_w8a8 requires azp to be folded into azp_adj
                # in the per-tensor case
                azp_adj = layer.input_zero_point * azp_adj
            layer.azp_adj = Parameter(azp_adj, requires_grad=False)
        else:
            layer.azp_adj = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(output_size_per_partition, input_size_per_partition, dtype=torch.int8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: compressed-tensors stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = PerTensorScaleParameter(
                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                )
                layer.register_parameter("input_zero_point", input_zero_point)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # TODO: add cutlass_scaled_mm_azp support
        x_q, x_scale = per_token_quant_int8(x)

        return int8_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py

 import logging
 
+import torch
+
 from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
 
 logger = logging.getLogger(__name__)

@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
     try:
         import deep_gemm
     except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
         return False
 
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
python/sglang/srt/layers/quantization/mxfp4.py

@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         topk_weights = topk_weights.to(
             torch.float32
         )  # aiter's moe_sorting requires topk_weights to be FP32
 
-        if hasattr(torch, "float4_e2m1fn_x2"):
-            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
-            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
-        else:
-            w13_weight = layer.w13_weight
-            w2_weight = layer.w2_weight
-
         output = fused_moe(
             x,
-            w13_weight,
-            w2_weight,
+            layer.w13_weight,
+            layer.w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
python/sglang/srt/layers/quantization/quark/quark_moe.py

@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
 
-        if hasattr(torch, "float4_e2m1fn_x2"):
-            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
-            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
-        else:
-            w13_weight = layer.w13_weight
-            w2_weight = layer.w2_weight
-
         output = fused_moe(
             x,
-            w13_weight,
-            w2_weight,
+            layer.w13_weight,
+            layer.w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
python/sglang/srt/layers/quantization/w4afp8.py

@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import is_npu, set_weight_attrs
 
+_is_npu = is_npu()
+if not _is_npu:
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
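The new guard in w4afp8.py is the usual pattern for platform-specific kernels: probe the platform once at import time and only import the CUDA-only module where it can exist. A generic sketch of that guarded-import idiom; the package and symbol names below are placeholders, not real sglang modules.

    import importlib.util

    def _has_module(name: str) -> bool:
        return importlib.util.find_spec(name) is not None

    # Import the accelerated implementation only where it is available,
    # and keep a pure-Python fallback otherwise.
    if _has_module("some_cuda_only_kernel_pkg"):           # placeholder package name
        from some_cuda_only_kernel_pkg import fast_moe     # hypothetical symbol
    else:
        def fast_moe(*args, **kwargs):
            raise NotImplementedError("accelerated MoE kernel not available on this platform")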