sglang / Commits / c1d2061f

Commit c1d2061f (unverified), authored Aug 05, 2025 by Ying Sheng, committed via GitHub on Aug 05, 2025

Add initial support for gpt-oss (#8824)
Parent: 556e4143

Showing 12 changed files with 1595 additions and 47 deletions (+1595, -47)
python/sglang/srt/layers/attention/triton_backend.py                   (+85,  -14)
python/sglang/srt/layers/attention/triton_ops/decode_attention.py      (+17,   -0)
python/sglang/srt/layers/attention/triton_ops/extend_attention.py      (+36,   -8)
python/sglang/srt/layers/linear.py                                      (+0,   -5)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                 (+134,  -8)
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py    (+178,  -3)
python/sglang/srt/layers/quantization/fp8_utils.py                     (+29,   -0)
python/sglang/srt/layers/quantization/mxfp4_tensor.py                  (+133,  -0)
python/sglang/srt/layers/quantization/unquant.py                       (+52,   -7)
python/sglang/srt/managers/schedule_batch.py                            (+4,   -2)
python/sglang/srt/models/gpt_oss.py                                    (+923,  -0)
python/sglang/srt/server_args.py                                        (+4,   -0)
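
For context when reviewing, one way to exercise the new model path is SGLang's offline engine API. This is not part of the commit, and the checkpoint path below is a placeholder:

import sglang as sgl

# Placeholder path; point it at a local gpt-oss checkpoint in HF layout.
llm = sgl.Engine(model_path="/path/to/gpt-oss-checkpoint")
out = llm.generate(["The capital of France is"], {"temperature": 0.0, "max_new_tokens": 16})
print(out[0]["text"])
llm.shutdown()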
python/sglang/srt/layers/attention/triton_backend.py

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
            self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+       self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(

@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                        forward_batch.req_pool_indices,
                        bs,
                        self.device,
+                       self.token_to_kv_pool_allocator,
                    )
                )
            window_num_kv_splits = torch.empty(

@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
            mask_indptr = None
            max_extend_len = None
        elif forward_batch.forward_mode.is_target_verify():
-           # TODO: Support sliding window in spec inference
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = torch.arange(
                0,

@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                self.req_to_token.stride(0),
            )
+           if self.sliding_window_size is not None and self.sliding_window_size > 0:
+               window_kv_indptr, window_kv_indices, window_kv_lens = (
+                   update_sliding_window_buffer(
+                       self.window_kv_indptr,
+                       self.req_to_token,
+                       self.sliding_window_size,
+                       forward_batch.seq_lens,
+                       forward_batch.req_pool_indices,
+                       bs,
+                       self.device,
+                       self.token_to_kv_pool_allocator,
+                   )
+               )

            custom_mask = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (
                forward_batch.seq_lens + self.num_draft_tokens

@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                    forward_batch.req_pool_indices,
                    bs,
                    self.device,
+                   self.token_to_kv_pool_allocator,
                )
            qo_indptr = self.qo_indptr

@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
            ):
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-               window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
-                   self.window_kv_indptr,
-                   window_kv_indices,
-                   self.req_to_token,
-                   self.sliding_window_size,
-                   seq_lens[:bs],
-                   req_pool_indices,
-                   bs,
+               window_kv_indptr, window_kv_indices, _ = (
+                   update_sliding_window_buffer_cuda_graph(
+                       self.window_kv_indptr,
+                       window_kv_indices,
+                       self.req_to_token,
+                       self.sliding_window_size,
+                       seq_lens[:bs],
+                       req_pool_indices,
+                       bs,
+                       self.token_to_kv_pool_allocator,
+                   )
                )
        else:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                self.req_to_token.stride(0),
            )
+           if self.sliding_window_size is not None and self.sliding_window_size > 0:
+               window_kv_indices = self.cuda_graph_window_kv_indices
+               window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+               window_kv_indptr, window_kv_indices, _ = (
+                   update_sliding_window_buffer_cuda_graph(
+                       self.window_kv_indptr,
+                       window_kv_indices,
+                       self.req_to_token,
+                       self.sliding_window_size,
+                       seq_lens,
+                       req_pool_indices,
+                       bs,
+                       self.token_to_kv_pool_allocator,
+                   )
+               )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
            ):
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_indices = self.cuda_graph_window_kv_indices
-               _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+               _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                    self.window_kv_indptr,
                    window_kv_indices,
                    self.req_to_token,

@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                    seq_lens[:bs],
                    req_pool_indices[:bs],
                    bs,
+                   self.token_to_kv_pool_allocator,
                )
                self.get_num_kv_splits(
                    window_num_kv_splits[:num_token], window_kv_lens[:bs]

@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                kv_indices,
                self.req_to_token.stride(0),
            )
+           if self.sliding_window_size is not None and self.sliding_window_size > 0:
+               window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+               window_kv_indices = self.cuda_graph_window_kv_indices
+               _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                   self.window_kv_indptr,
+                   window_kv_indices,
+                   self.req_to_token,
+                   self.sliding_window_size,
+                   seq_lens,
+                   req_pool_indices,
+                   bs,
+                   self.token_to_kv_pool_allocator,
+               )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
+       sk=None,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:

@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
            self.forward_metadata.max_extend_len,
            layer.scaling,
            layer.logit_cap,
-           sliding_window_size,
+           sliding_window_size=sliding_window_size,
+           sk=sk,
        )
        return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
+       sk=None,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.

@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
            self.max_kv_splits,
            layer.scaling,
            layer.logit_cap,
+           sk=sk,
        )
        return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
    req_pool_indices,
    bs,
    device,
+   token_to_kv_pool_allocator=None,
):
    window_kv_lens = torch.minimum(
        seq_lens,
-       torch.tensor(sliding_window_size + 1),
+       torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
        window_kv_indices,
        req_to_token.stride(0),
    )
+   # full to swa index mapping
+   if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+       kv_last_index = window_kv_indptr[-1]
+       window_kv_indices[:kv_last_index] = (
+           token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+               window_kv_indices[:kv_last_index]
+           )
+       )
    return window_kv_indptr, window_kv_indices, window_kv_lens

@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
    seq_lens,
    req_pool_indices,
    bs,
+   token_to_kv_pool_allocator=None,
):
    window_kv_lens = torch.minimum(
        seq_lens,
-       torch.tensor(sliding_window_size + 1),
+       torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
        window_kv_indices,
        req_to_token.stride(0),
    )
-   return window_kv_indptr, window_kv_lens
+   # full to swa index mapping
+   if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+       kv_last_index = window_kv_indptr[-1]
+       window_kv_indices[:kv_last_index] = (
+           token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+               window_kv_indices[:kv_last_index]
+           )
+       )
+   return window_kv_indptr, window_kv_indices, window_kv_lens
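
For orientation, the host-side bookkeeping that update_sliding_window_buffer performs can be sketched in plain PyTorch as below. This is a simplified stand-in: the real code also fills window_kv_indices with a Triton kernel and, as added in this commit, optionally remaps indices from the full KV pool to the SWA pool via translate_loc_from_full_to_swa.

import torch

def sliding_window_lens_and_indptr(seq_lens: torch.Tensor, sliding_window_size: int, bs: int):
    # Per request, only the last `sliding_window_size` cached tokens stay visible to window attention.
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    return window_kv_lens, window_kv_indptr

lens, indptr = sliding_window_lens_and_indptr(torch.tensor([5, 300, 40]), 128, 3)
# lens -> [5, 128, 40]; indptr -> [0, 5, 133, 173]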
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
    O,
    kv_indptr,
    num_kv_splits,
+   sk_ptr,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,

@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
    MIN_BLOCK_KV: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
+   HAS_SK: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

+   if HAS_SK:
+       cur_sk = tl.load(sk_ptr + cur_head)
+       e_sum += tl.exp(cur_sk - e_max)
+
    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,

@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
    kv_indptr,
    num_kv_splits,
    max_kv_splits,
+   sk=None,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    MAX_KV_SPLITS = max_kv_splits
+   HAS_SK = sk is not None

    extra_kargs = {}
    if _is_hip:

@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
        o,
        kv_indptr,
        num_kv_splits,
+       sk,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),

@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
+       HAS_SK=HAS_SK,
        num_warps=4,
        num_stages=2,
        **extra_kargs,

@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+   sk=None,
):
    _decode_att_m_fwd(
        q,

@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+       sk,
    )

@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+   sk=None,
):
    _decode_grouped_att_m_fwd(
        q,

@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+       sk,
    )

@@ -687,6 +701,7 @@ def decode_attention_fwd(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+   sk=None,
):
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -709,6 +724,7 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+           sk=sk,
        )
    else:
        # GQA/MQA/MLA

@@ -725,4 +741,5 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+           sk=sk,
        )
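
The new sk argument threads per-head attention-sink logits into the decode softmax. A plain-PyTorch reference of the reduction for a single head and query (a sketch of the math, not the Triton kernel):

import torch

def decode_attn_with_sink(logits: torch.Tensor, v: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    # logits: [kv_len] scores for one head, v: [kv_len, head_dim], sink: scalar tensor.
    e_max = logits.max()
    p = torch.exp(logits - e_max)
    # The sink enlarges the softmax denominator but contributes no value vector,
    # mirroring `e_sum += tl.exp(cur_sk - e_max)` before the final `acc / e_sum` store.
    e_sum = p.sum() + torch.exp(sink - e_max)
    return (p / e_sum) @ v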
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,6 +51,7 @@ def _fwd_kernel(
    kv_indices,
    mask_ptr,
    mask_indptr,
+   sk_ptr,
    sm_scale,
    kv_group_num,
    stride_qbs,

@@ -78,6 +79,7 @@ def _fwd_kernel(
    IS_CAUSAL: tl.constexpr,
    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
    STORE_TRANSPOSE: tl.constexpr,
+   HAS_SK: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)

@@ -178,13 +180,17 @@ def _fwd_kernel(
                final_mask &= custom_mask
            if SLIDING_WINDOW_SIZE > 0:
                # Add mask where q_id <= kv_id + sliding_window_size
-               window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                   start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-               )
+               # q_id = prefix_len + cur_m, kv_id = cur_n
+               window_mask = (
+                   cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+               ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
                final_mask &= window_mask
            qk = tl.where(final_mask, qk, float("-inf"))

-           n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+           row_max = tl.max(qk, 1)
+           row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+           n_e_max = tl.maximum(row_max_fixed, e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)

@@ -242,6 +248,7 @@ def _fwd_kernel(
        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

+       final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr

@@ -254,18 +261,30 @@ def _fwd_kernel(
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
-           qk = tl.where(custom_mask, qk, float("-inf"))
+           final_mask &= custom_mask
        elif IS_CAUSAL:
            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causual &= mask_m[:, None] & mask_n[None, :]
-           qk = tl.where(mask_causual, qk, float("-inf"))
+           final_mask &= mask_causual
        else:
            mask_non_causal = mask_m[:, None] & mask_n[None, :]
-           qk = tl.where(mask_non_causal, qk, float("-inf"))
+           final_mask &= mask_non_causal

+       if SLIDING_WINDOW_SIZE > 0:
+           # Add mask where q_id <= kv_id + sliding_window_size
+           window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+               start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+           )
+           final_mask &= window_mask
+       qk = tl.where(final_mask, qk, float("-inf"))

-       n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+       row_max = tl.max(qk, 1)
+       row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+       n_e_max = tl.maximum(row_max_fixed, e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

@@ -283,6 +302,10 @@ def _fwd_kernel(
        e_max = n_e_max

+   if HAS_SK:
+       cur_sk = tl.load(sk_ptr + cur_head)
+       deno += tl.exp(cur_sk - e_max)
+
    offs_o = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs

@@ -321,6 +344,7 @@ def extend_attention_fwd(
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
    sliding_window_size=-1,
+   sk=None,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -386,6 +410,8 @@ def extend_attention_fwd(
    # Skip custom mask for prefix part
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+   HAS_SK = sk is not None
+
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

@@ -405,6 +431,7 @@ def extend_attention_fwd(
        kv_indices,
        custom_mask,
        mask_indptr,
+       sk,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),

@@ -431,6 +458,7 @@ def extend_attention_fwd(
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+       HAS_SK=HAS_SK,
        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
        num_stages=num_stages,
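
The window predicate used above (q_id <= kv_id + SLIDING_WINDOW_SIZE, combined with causality) is easier to see on a dense mask. A toy reconstruction of the masking logic, not the Triton kernel:

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    q_id = torch.arange(seq_len)[:, None]
    kv_id = torch.arange(seq_len)[None, :]
    causal = q_id >= kv_id                  # IS_CAUSAL branch
    in_window = q_id <= kv_id + window      # SLIDING_WINDOW_SIZE branch
    return causal & in_window               # each query sees at most window + 1 positions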
python/sglang/srt/layers/linear.py

@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                else self.weight_loader
            ),
        )
-       if not reduce_results and (bias and not skip_bias_add):
-           raise ValueError(
-               "When not reduce the results, adding bias to the "
-               "results can lead to incorrect results"
-           )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        enable_flashinfer_cutlass_moe: Optional[bool] = False,
+       activation_alpha: Optional[float] = None,
+       swiglu_limit: Optional[float] = None,
+       use_weight_loader_fused: bool = False,
+       with_bias=False,
    ):
        super().__init__()

@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
        self.expert_map_cpu = None
        self.expert_map_gpu = None

+       # For activation
+       self.activation_alpha = activation_alpha
+       self.swiglu_limit = swiglu_limit
+
        if enable_flashinfer_cutlass_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
            enable_flashinfer_cutlass_moe = False

@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-               self.use_triton_kernels
+               self.use_triton_kernels, with_bias=with_bias
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)

@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
-           intermediate_size=self.intermediate_size_per_partition,
+           intermediate_size_per_partition=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
-           weight_loader=self.weight_loader,
+           weight_loader=(
+               self.weight_loader
+               if not use_weight_loader_fused
+               else self.weight_loader_fused
+           ),
+           with_bias=with_bias,
        )

    def _load_per_tensor_weight_scale(

@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+       is_bias: bool = False,
    ):
        # Load grouped weight scales for group quantization
        # or model weights

@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
+               is_bias=is_bias,
            )
-       elif shard_id in ("w1", "w3"):
+       elif shard_id in ("w1", "w3", "w13"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
+               is_bias=is_bias,
            )

    def _load_per_channel_weight_scale(

@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+       is_bias: bool = False,
    ):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-       shard_size = expert_data.shape[shard_dim] // 2
+       assert shard_id in {"w1", "w3", "w13"}
+       if is_bias:
+           # if this weight is a bias, the last dimension must be the sharded dimension
+           shard_dim = -1
+
+       if shard_id in {"w1", "w3"}:
+           # non-fused version
+           shard_size = expert_data.shape[shard_dim] // 2
+       elif shard_id in {"w13"}:
+           # fused version
+           shard_size = expert_data.shape[shard_dim]
+       else:
+           raise NotImplementedError

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        # w3, up_proj: Load into second logical weight of w13.
        # trtllm cutlass kernel assumes differently
-       assert shard_id in ("w1", "w3")
        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
            start = shard_size

@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            if not self.use_presharded_weights:
-               if self.use_triton_kernels:
+               if not is_bias and self.use_triton_kernels:
+                   # do not transpose for bias
                    loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size

@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+       is_bias: bool = False,
    ):
        """Load w2 weights for down projection.

@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
-       shard_size = expert_data.shape[shard_dim]
+       if is_bias:
+           # this expert_data is a bias, not weight,
+           # for w2_bias in TP, it does not need to be sharded
+           shard_size = expert_data.shape[-1]
+       else:
+           # this parameter is a weight matrix
+           # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+           shard_size = expert_data.shape[shard_dim]
        if _is_cpu:
            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(

@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
                not self.use_presharded_weights,
            )
        else:
-           if not self.use_presharded_weights:
+           if not is_bias and not self.use_presharded_weights:
                if self.use_triton_kernels:
                    loaded_weight = loaded_weight.transpose(-2, -1)
                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:

@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
            )
        return

+   def weight_loader_fused(
+       self,
+       param: torch.nn.Parameter,
+       loaded_weight: torch.Tensor,
+       weight_name: str,
+       shard_id: str,
+   ) -> None:
+       tp_rank = self.moe_tp_rank
+
+       # compressed-tensors checkpoints with packed weights are stored flipped
+       # TODO: check self.quant_method.quant_config.quant_format
+       # against known CompressionFormat enum values that have this quality
+       loaded_weight = (
+           loaded_weight.t().contiguous()
+           if (
+               self.quant_method.__class__.__name__
+               == "CompressedTensorsWNA16MoEMethod"
+           )
+           else loaded_weight
+       )
+
+       if shard_id not in ("w13", "w2"):
+           raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+       # Fetch the dim to shard the parameter/loaded weight
+       # based on the shard id. This will be whatever
+       # dimension intermediate_size is used.
+       SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+       SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+       expert_data = param.data
+       is_bias = expert_data.dim() == 2
+
+       # is_transposed: if the dim to shard the weight
+       # should be flipped. Required by GPTQ, compressed-tensors
+       # should be whatever dimension intermediate_size is
+       is_transposed = getattr(param, "is_transposed", False)
+       if self.use_triton_kernels:
+           is_transposed = True
+       shard_dim = (
+           SHARD_ID_TO_SHARDED_DIM[shard_id]
+           if not is_transposed
+           else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+       )
+
+       # Case model weights
+       if "weight" in weight_name:
+           self._load_model_weight_or_group_weight_scale(
+               shard_id=shard_id,
+               shard_dim=shard_dim,
+               loaded_weight=loaded_weight,
+               expert_data=expert_data,
+               tp_rank=tp_rank,
+               is_bias=is_bias,
+           )
+           return
+       else:
+           logging.warning(
+               f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+           )
+
    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
        assert self.quant_method is not None

@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
        # Matrix multiply.
        with use_symmetric_memory(get_tp_group()) as sm:
+           kwargs = {}
+           if self.activation_alpha is not None:
+               kwargs["activation_alpha"] = self.activation_alpha
+           if self.swiglu_limit is not None:
+               kwargs["swiglu_limit"] = self.swiglu_limit
+
            final_hidden_states = self.quant_method.apply(
                layer=self,
                x=hidden_states,

@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
                    == "ModelOptNvFp4FusedMoEMethod"
                    else {}
                ),
+               **kwargs,
            )
            sm.tag(final_hidden_states)

@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
            ]
        ]

+   @classmethod
+   def make_expert_params_mapping_fused(
+       cls,
+       ckpt_gate_up_proj_name: str,
+       ckpt_down_proj_name: str,
+       ckpt_gate_up_proj_bias_name: str,
+       ckpt_down_proj_bias_name: str,
+   ):
+       return [
+           ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+           (
+               "experts.w13_weight_bias",
+               f"experts.{ckpt_gate_up_proj_bias_name}",
+               "w13",
+           ),
+           ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+           ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+       ]
+
    @classmethod
    def make_expert_input_scale_params_mapping(
        cls,
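
make_expert_params_mapping_fused pairs the fused w13/w2 parameter names with checkpoint tensor names. A sketch of how a model's load_weights loop might consume it; the checkpoint names passed here are illustrative assumptions, not taken from this diff:

from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

# Hypothetical checkpoint tensor names; substitute whatever the target checkpoint uses.
expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
    ckpt_gate_up_proj_name="gate_up_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
    ckpt_down_proj_bias_name="down_proj_bias",
)
# Yields tuples such as ("experts.w13_weight", "experts.gate_up_proj", "w13");
# a load_weights loop matches the second element against checkpoint names and then
# calls weight_loader_fused(param, loaded_weight, weight_name, shard_id).
for param_name, weight_name, shard_id in expert_params_mapping:
    ...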
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py

@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
import torch

from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import (
+   FlexCtx,
+   FnSpecs,
+   FusedActivation,
+   PrecisionConfig,
+   matmul_ogs,
+)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx

from sglang.srt.utils import direct_register_custom_op
+from triton_kernels.swiglu import swiglu_fn

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput


+def quantize(w, dtype, dev, **opt):
+   if dtype == "bf16":
+       return w.to(torch.bfloat16), InFlexData()
+   elif dtype == "fp8":
+       wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+       return (
+           wq,
+           InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+           MicroscalingCtx(),
+       )
+   else:
+       assert dtype == "mx4", f"{dtype=}"
+       swizzle_mx_scale = opt["swizzle_mx_scale"]
+       swizzle_axis = 2 if swizzle_mx_scale else None
+       w = w.to(torch.bfloat16)
+       w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+           w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+       )
+       return (
+           w,
+           InFlexData(),
+           MicroscalingCtx(
+               weight_scale=mx_scales,
+               swizzle_mx=swizzle_mx_scale,
+               actual_weight_scale_shape=weight_scale_shape,
+           ),
+       )


def triton_kernel_moe_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,

@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
    )

    return intermediate_cache3


+def triton_kernel_moe_with_bias_forward(
+   hidden_states: torch.Tensor,
+   w1: torch.Tensor,
+   b1: torch.Tensor,
+   w2: torch.Tensor,
+   b2: torch.Tensor,
+   topk_output: TopKOutput,
+   inplace: bool = False,
+   activation: str = "silu",
+   use_fp8_w8a8: bool = False,
+   per_channel_quant: bool = False,
+   global_num_experts: int = -1,
+   expert_map: Optional[torch.Tensor] = None,
+   w1_scale: Optional[torch.Tensor] = None,
+   w2_scale: Optional[torch.Tensor] = None,
+   a1_scale: Optional[torch.Tensor] = None,
+   a2_scale: Optional[torch.Tensor] = None,
+   block_shape: Optional[list[int]] = None,
+   activation_alpha: Optional[float] = None,
+   swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+   assert topk_output.format.is_triton_kernel()
+   routing_data, gather_idx, scatter_idx = topk_output
+
+   return triton_kernel_fused_experts_with_bias(
+       hidden_states,
+       w1,
+       b1,
+       w2,
+       b2,
+       routing_data,
+       gather_idx,
+       scatter_idx,
+       inplace=inplace,
+       activation=activation,
+       use_fp8_w8a8=use_fp8_w8a8,
+       per_channel_quant=per_channel_quant,
+       global_num_experts=global_num_experts,
+       expert_map=expert_map,
+       w1_scale=w1_scale,
+       w2_scale=w2_scale,
+       a1_scale=a1_scale,
+       a2_scale=a2_scale,
+       block_shape=block_shape,
+       activation_alpha=activation_alpha,
+       swiglu_limit=swiglu_limit,
+   )
+
+
+def triton_kernel_fused_experts_with_bias(
+   hidden_states: torch.Tensor,
+   w1: torch.Tensor,
+   b1: torch.Tensor,
+   w2: torch.Tensor,
+   b2: torch.Tensor,
+   routing_data: RoutingData,
+   gather_indx: GatherIndx,
+   scatter_indx: ScatterIndx,
+   inplace: bool = False,
+   activation: str = "silu",
+   use_fp8_w8a8: bool = False,
+   per_channel_quant: bool = False,
+   global_num_experts: int = -1,
+   expert_map: Optional[torch.Tensor] = None,
+   w1_scale: Optional[torch.Tensor] = None,
+   w2_scale: Optional[torch.Tensor] = None,
+   a1_scale: Optional[torch.Tensor] = None,
+   a2_scale: Optional[torch.Tensor] = None,
+   block_shape: Optional[list[int]] = None,
+   activation_alpha: Optional[float] = None,
+   swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+   # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+   assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+   assert per_channel_quant == False, "per_channel_quant is not supported"
+   assert expert_map == None, "expert_map is not supported"
+   assert w1_scale == None, "w1_scale is not supported"
+   assert w2_scale == None, "w2_scale is not supported"
+   assert a1_scale == None, "a1_scale is not supported"
+   assert a2_scale == None, "a2_scale is not supported"
+   assert block_shape == None, "block_shape is not supported"
+
+   # type check
+   assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+   assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+   assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+   # Shape check
+   assert hidden_states.ndim == 2, "hidden_states must be 2D"
+   assert (
+       hidden_states.shape[-1] == w1.shape[-2]
+   ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+   assert (
+       w2.shape[-1] == w1.shape[1]
+   ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+   # feature check
+   assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+   E, _, _ = w1.shape
+
+   if global_num_experts == -1:
+       global_num_experts = E
+
+   device = "cuda"
+   optg = dict()
+   w1, w1_flex = quantize(w1, "bf16", device, **optg)
+   w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+   w2, w2_flex = quantize(w2, "bf16", device, **optg)
+   w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+   act = FusedActivation(
+       FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+       (activation_alpha, swiglu_limit),
+       2,
+   )
+
+   intermediate_cache = matmul_ogs(
+       hidden_states,
+       w1,
+       b1,
+       routing_data,
+       gather_indx=gather_indx,
+       precision_config=w1_pcg,
+       gammas=None,
+       fused_activation=act,
+   )
+
+   return matmul_ogs(
+       intermediate_cache,
+       w2,
+       b2,
+       routing_data,
+       scatter_indx=scatter_indx,
+       precision_config=w2_pcg,
+       gammas=routing_data.gate_scal,
+   )
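
The FusedActivation above passes ("alpha", "limit") constants to triton_kernels' swiglu. The exact formula lives in triton_kernels.swiglu.swiglu_fn and is not shown in this diff; as an assumption based on the publicly described gpt-oss activation, it behaves roughly like this PyTorch reference:

import torch

def swiglu_ref(gate: torch.Tensor, up: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    # Clamp both halves, then gate with a sigmoid-scaled gate projection (assumed formula).
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(alpha * gate) * (up + 1)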
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -4,6 +4,7 @@ import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.layers.utils import is_sm100_supported

try:

@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.utils import (
+   align,
    ceil_div,
    get_bool_env_var,
    get_cuda_version,
    get_device_capability,

@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
    return output.to(dtype=input_2d.dtype).view(*output_shape)


+def dequant_mxfp4(
+   w_block: torch.Tensor,
+   w_scale: torch.Tensor,
+   out_dtype,
+) -> torch.Tensor:
+   """
+   :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+   :param w_scale: (batch, n, k), uint8
+   :return: (batch, n, k * 32), float32
+   """
+   assert w_block.dtype == torch.uint8
+   assert w_scale.dtype == torch.uint8
+
+   batch, n, k, pack_dim = w_block.shape
+   batch_, n_, k_ = w_scale.shape
+   assert pack_dim == 16
+   assert batch == batch_
+   assert n == n_
+   assert k == k_
+
+   out_raw = MXFP4QuantizeUtil.dequantize(
+       quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+   )
+   return out_raw.reshape(batch, n, k * 32)
+
+
def input_to_float8(
    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
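
A quick shape check for dequant_mxfp4 with random packed data, following the docstring layout (two FP4 values per byte, 32 values per scale entry); the sizes below are arbitrary:

import torch
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4

w_block = torch.randint(0, 256, (8, 4, 6, 16), dtype=torch.uint8)  # packed mxfp4 payload
w_scale = torch.randint(100, 140, (8, 4, 6), dtype=torch.uint8)    # E8M0 scales, biased by 127
w = dequant_mxfp4(w_block, w_scale, torch.bfloat16)
assert w.shape == (8, 4, 192)  # k * 32 unpacked values per (batch, n) row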
python/sglang/srt/layers/quantization/mxfp4_tensor.py (new file, 0 → 100644)

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
class MXFP4QuantizeUtil:
    E2M1_max = 6.0
    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

    @classmethod
    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
        Args:
            input (torch.Tensor): The input tensor to be quantized.
            block_sizes (dict | None): The block sizes for quantization.
        """

        def cast_fp4(x):
            sign = torch.sign(x)
            sign_bit = (2 - sign) // 2
            ord_ = torch.sum(
                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
            )
            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
            return fp4_val

        def fuse_uint4_to_uint8(x):
            # If the last dimension is odd, pad with zeros
            # If this behavior is not desired, please modify the code accordingly
            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
            new_data = right_side.clone() << 4  # Put odd indices (higher addresses) in high bits
            new_data[..., : left_side.shape[-1]] += left_side  # Put even indices in low bits
            return new_data

        if block_size is None:
            block_size = 32

        original_shape = input.shape
        original_dtype = input.dtype
        input = input.view(-1, block_size)
        # get scales
        input_amax = input.abs().max(dim=-1, keepdim=True).values
        descale = input_amax / cls.E2M1_max
        min_value = torch.tensor(-127.0, device=descale.device)
        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))

        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
        input_q = cast_fp4(input)
        input_q = fuse_uint4_to_uint8(input_q)
        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
        return cls(original_shape, original_dtype, input_q), e8m0_scale

    @classmethod
    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
        """Dequantze MXFP4 packed tensor to a target dtype."""

        def unfuse_uint8_to_uint4(x):
            """Unfuse uint8 values back to uint4 values.

            This is the inverse operation of fuse_uint4_to_uint8.
            """
            # Extract the lower 4 bits (even indices)
            left_side = x & 0x0F
            # Extract the upper 4 bits (odd indices)
            right_side = (x >> 4) & 0x0F
            # Create a new tensor with alternating values
            shape = list(x.shape)
            shape[-1] = shape[-1] * 2
            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
            # Fill in the values - even indices get low bits, odd indices get high bits
            result[..., 0::2] = left_side  # Even indices from low bits
            result[..., 1::2] = right_side  # Odd indices from high bits
            return result

        e8m0_scale = scale
        block_size = block_sizes[-1]

        # Unfuse the uint8 values back to uint4
        x_unfused = unfuse_uint8_to_uint4(quantized_data)
        # Extract sign and magnitude
        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
            torch.float32
        )  # Extract sign bit and convert to +1/-1
        magnitude = x_unfused & 0b0111  # Extract magnitude bits
        magnitude = magnitude.to(torch.long)

        # Create a tensor with the E2M1 values
        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)

        # Use gather to index the values tensor properly
        # We need to reshape magnitude to match the dimensions we want to gather along
        original_shape = magnitude.shape
        x_float = values[magnitude.reshape(-1)].reshape(original_shape)

        # Apply sign and scale
        x_float = sign.float() * x_float

        # Reshape to apply block-wise scaling
        x_float = x_float.reshape(-1, block_size)

        # Apply the E8M0 scale
        scale_factor = torch.exp2(e8m0_scale.float() - 127)
        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting

        # Apply scaling and reshape back to original shape
        x_float = x_float * scale_factor

        # Reshape back to the original shape
        return x_float.reshape(original_shape).to(dtype)
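
As a small sanity check of the packing convention (low nibble = even index, high nibble = odd index), dequantizing one hand-packed 32-value block with a unit E8M0 scale:

import torch
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil

# Every byte packs two FP4 codes; per the E2M1 table above, 0x1 -> +0.5 and 0x9 -> -0.5.
packed = torch.full((16,), 0x91, dtype=torch.uint8)   # 16 bytes = one 32-value block
scale = torch.tensor([127], dtype=torch.uint8)        # E8M0 exponent 0 (bias 127)
out = MXFP4QuantizeUtil.dequantize(packed, dtype=torch.float32, scale=scale, block_sizes=[32])
# out alternates +0.5, -0.5 across the 32 positions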
python/sglang/srt/layers/quantization/unquant.py

@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

-   def __init__(self, use_triton_kernels: bool = False):
+   def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
        super().__init__()
        self.use_triton_kernels = use_triton_kernels
+       self.with_bias = with_bias

        self.triton_kernel_moe_forward = None
+       self.triton_kernel_moe_with_bias_forward = None
        if torch.cuda.is_available() and has_triton_kernels:
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_forward as _tk_forward,
            )
+           from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+               triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
+           )

            self.triton_kernel_moe_forward = _tk_forward
+           self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

    def create_weights(
        self,

@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

+       if self.with_bias:
+           w13_weight_bias = torch.nn.Parameter(
+               torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+               requires_grad=False,
+           )
+           layer.register_parameter("w13_weight_bias", w13_weight_bias)
+           set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
        # down_proj (row parallel)
        w2_weight_n, w2_weight_k = (
            hidden_size,

@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

+       if self.with_bias:
+           w2_weight_bias = torch.nn.Parameter(
+               torch.empty(num_experts, hidden_size, dtype=torch.float32),
+               requires_grad=False,
+           )
+           layer.register_parameter("w2_weight_bias", w2_weight_bias)
+           set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _use_aiter:
            layer.w13_weight = torch.nn.Parameter(

@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
+       activation_alpha: Optional[float] = None,
+       swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
+       kwargs = {}
+       if activation_alpha is not None:
+           kwargs["activation_alpha"] = activation_alpha
+       if swiglu_limit is not None:
+           kwargs["swiglu_limit"] = swiglu_limit
        return self.forward(
            x=x,

@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            inplace=inplace,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
+           **kwargs,
        )

    def forward_cuda(

@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
+       activation_alpha: Optional[float] = None,
+       swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:

        if self.use_triton_kernels:
-           return self.triton_kernel_moe_forward(
-               hidden_states=x,
-               w1=layer.w13_weight,
-               w2=layer.w2_weight,
-               topk_output=topk_output,
-           )
+           if self.with_bias:
+               return self.triton_kernel_moe_with_bias_forward(
+                   hidden_states=x,
+                   w1=layer.w13_weight,
+                   w2=layer.w2_weight,
+                   b1=layer.w13_weight_bias,
+                   b2=layer.w2_weight_bias,
+                   topk_output=topk_output,
+                   activation=activation,
+                   activation_alpha=activation_alpha,
+                   swiglu_limit=swiglu_limit,
+               )
+           else:
+               return self.triton_kernel_moe_forward(
+                   hidden_states=x,
+                   w1=layer.w13_weight,
+                   w2=layer.w2_weight,
+                   topk_output=topk_output,
+               )
        else:
            if _use_aiter:
                assert not no_combine, "unsupported"
python/sglang/srt/managers/schedule_batch.py

@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-           assert isinstance(tree_cache, SWARadixCache) or isinstance(
-               tree_cache, SWAChunkCache
+           assert (
+               tree_cache is None
+               or isinstance(tree_cache, SWARadixCache)
+               or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True
python/sglang/srt/models/gpt_oss.py
0 → 100644
View file @
c1d2061f
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GptOss model compatible with HuggingFace weights."""
import
logging
from
collections.abc
import
Iterable
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
sglang.srt.distributed
import
(
get_moe_tensor_parallel_rank
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
from
sglang.srt.eplb.expert_distribution
import
get_global_expert_distribution_recorder
from
sglang.srt.eplb.expert_location
import
ModelConfigForExpertLocation
from
sglang.srt.layers.communicator
import
LayerCommunicator
,
LayerScatterModes
from
sglang.srt.layers.dp_attention
import
(
get_attention_tp_rank
,
get_attention_tp_size
,
get_local_attention_dp_size
,
)
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
(
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.ep_moe.layer
import
get_moe_impl_class
from
sglang.srt.layers.moe.topk
import
TopK
from
sglang.srt.layers.moe.utils
import
DeepEPMode
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.fp8_utils
import
dequant_mxfp4
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.rotary_embedding
import
get_rope
from
sglang.srt.layers.utils
import
PPMissingLayer
,
get_layer_id
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
PPProxyTensors
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.utils
import
add_prefix
,
make_layers
class
GptOssConfig
(
PretrainedConfig
):
model_type
=
"gpt_oss"
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
logger
=
logging
.
getLogger
(
__name__
)
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def
get_attention_sliding_window_size
(
config
):
return
config
.
sliding_window
-
1
class
GptOssSparseMoeBlock
(
nn
.
Module
):
def
__init__
(
self
,
layer_id
:
int
,
config
:
GptOssConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
layer_id
=
layer_id
self
.
activation
=
config
.
hidden_act
self
.
activation_alpha
=
getattr
(
config
,
"hidden_act_alpha"
,
1.702
)
self
.
swiglu_limit
=
config
.
swiglu_limit
if
self
.
tp_size
>
config
.
num_local_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
config
.
num_local_experts
}
."
)
self
.
topk
=
TopK
(
top_k
=
config
.
num_experts_per_tok
,
renormalize
=
True
,
)
experts_type
=
get_moe_impl_class
()
extra_kwargs
=
{}
if
experts_type
.
__name__
==
"FusedMoE"
:
extra_kwargs
=
{
"enable_flashinfer_cutlass_moe"
:
global_server_args_dict
[
"enable_flashinfer_cutlass_moe"
],
"use_weight_loader_fused"
:
True
,
# for moe gate_up_proj and down_proj and their bias loading
}
self
.
experts
=
experts_type
(
num_experts
=
config
.
num_local_experts
+
global_server_args_dict
[
"ep_num_redundant_experts"
],
top_k
=
config
.
num_experts_per_tok
,
layer_id
=
layer_id
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
activation
=
self
.
activation
,
activation_alpha
=
self
.
activation_alpha
,
swiglu_limit
=
self
.
swiglu_limit
,
with_bias
=
True
,
prefix
=
add_prefix
(
"experts"
,
prefix
),
**
(
dict
(
deepep_mode
=
DeepEPMode
[
global_server_args_dict
[
"deepep_mode"
]])
if
global_server_args_dict
[
"moe_a2a_backend"
].
is_deepep
()
else
{}
),
**
extra_kwargs
,
)
self
.
router
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
num_local_experts
,
bias
=
True
,
quant_config
=
None
,
prefix
=
add_prefix
(
"gate"
,
prefix
),
params_dtype
=
config
.
torch_dtype
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
Optional
[
ForwardBatch
]
=
None
)
->
torch
.
Tensor
:
if
not
global_server_args_dict
[
"moe_a2a_backend"
].
is_deepep
():
return
self
.
forward_normal
(
hidden_states
)
else
:
raise
Exception
(
"forward_deepep branch not implemented yet"
)
def
get_moe_weights
(
self
):
return
[
x
.
data
for
name
,
x
in
self
.
experts
.
named_parameters
()
if
name
not
in
[
"correction_bias"
]
]
def
forward_normal
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
router
(
hidden_states
)
kwargs
=
{
"hidden_states"
:
hidden_states
}
if
self
.
topk
is
not
None
:
kwargs
[
"topk_output"
]
=
self
.
topk
(
hidden_states
,
router_logits
)
else
:
kwargs
[
"router_logits"
]
=
router_logits
final_hidden_states
=
self
.
experts
(
**
kwargs
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
ans
=
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
ans
class
GptOssAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
layer_id
:
int
=
0
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
head_dim
:
Optional
[
int
]
=
None
,
rms_norm_eps
:
float
=
1e-06
,
attention_bias
:
bool
=
False
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
sliding_window_size
:
int
=
-
1
,
# if -1, normal attention, else, window attention.
layer_type
:
str
=
""
,
params_dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
sliding_window_size
=
sliding_window_size
attn_tp_rank
=
get_attention_tp_rank
()
attn_tp_size
=
get_attention_tp_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
attn_tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
attn_tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
attn_tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
attn_tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
attn_tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
attn_tp_size
)
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
attention_bias
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
tp_rank
=
attn_tp_rank
,
tp_size
=
attn_tp_size
,
prefix
=
add_prefix
(
"qkv_proj"
,
prefix
),
)
self
.
sinks
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_heads
,
dtype
=
params_dtype
),
requires_grad
=
False
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
attention_bias
,
quant_config
=
quant_config
,
tp_rank
=
attn_tp_rank
,
tp_size
=
attn_tp_size
,
reduce_results
=
False
,
params_dtype
=
params_dtype
,
prefix
=
add_prefix
(
"o_proj"
,
prefix
),
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
assert
layer_type
in
{
"sliding_attention"
,
"full_attention"
}
use_sliding_window
=
layer_type
==
"sliding_attention"
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
layer_id
=
layer_id
,
prefix
=
add_prefix
(
"attn"
,
prefix
),
sliding_window_size
=
(
sliding_window_size
if
use_sliding_window
else
-
1
),
)
self
.
layer_id
=
layer_id
def
forward_prepare
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
):
if
hidden_states
.
shape
[
0
]
==
0
:
return
hidden_states
,
forward_batch
,
None
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
inner_state
=
q
,
k
,
v
,
forward_batch
return
None
,
forward_batch
,
inner_state
def
forward_core
(
self
,
intermediate_state
):
hidden_states
,
forward_batch
,
inner_state
=
intermediate_state
if
inner_state
is
None
:
return
hidden_states
attn_output
=
self
.
attn
(
*
inner_state
,
sk
=
self
.
sinks
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
s
=
self
.
forward_prepare
(
positions
=
positions
,
hidden_states
=
hidden_states
,
forward_batch
=
forward_batch
,
)
return
self
.
forward_core
(
s
)
class
GptOssDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
GptOssConfig
,
layer_id
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
sliding_window_size
:
int
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
head_dim
=
getattr
(
config
,
"head_dim"
,
config
.
hidden_size
//
config
.
num_attention_heads
)
rms_norm_eps
=
config
.
rms_norm_eps
attention_bias
=
config
.
attention_bias
if
sliding_window_size
is
None
:
self
.
sliding_window_size
=
get_attention_sliding_window_size
(
self
.
config
)
else
:
self
.
sliding_window_size
=
sliding_window_size
self
.
self_attn
=
GptOssAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
layer_id
=
layer_id
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
head_dim
=
head_dim
,
rms_norm_eps
=
rms_norm_eps
,
attention_bias
=
attention_bias
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"self_attn"
,
prefix
),
sliding_window_size
=
self
.
sliding_window_size
,
layer_type
=
config
.
layer_types
[
layer_id
],
params_dtype
=
config
.
torch_dtype
,
)
self
.
layer_id
=
layer_id
self
.
attn_tp_size
=
get_attention_tp_size
()
self
.
attn_tp_rank
=
get_attention_tp_rank
()
self
.
local_dp_size
=
get_local_attention_dp_size
()
# GptOss all layers are sparse and have no nextn now
self
.
is_layer_sparse
=
True
is_previous_layer_sparse
=
True
self
.
layer_scatter_modes
=
LayerScatterModes
.
init_new
(
layer_id
=
layer_id
,
num_layers
=
config
.
num_hidden_layers
,
is_layer_sparse
=
self
.
is_layer_sparse
,
is_previous_layer_sparse
=
is_previous_layer_sparse
,
)
if
self
.
is_layer_sparse
:
self
.
mlp
=
GptOssSparseMoeBlock
(
layer_id
=
self
.
layer_id
,
config
=
config
,
quant_config
=
quant_config
,
prefix
=
add_prefix
(
"mlp"
,
prefix
),
)
else
:
raise
NotImplementedError
(
"Dense MLP is not implemented for GptOssDecoderLayer. "
"Please use GptOssSparseMoeBlock instead."
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
layer_communicator
=
LayerCommunicator
(
layer_scatter_modes
=
self
.
layer_scatter_modes
,
input_layernorm
=
self
.
input_layernorm
,
post_attention_layernorm
=
self
.
post_attention_layernorm
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
,
residual
=
self
.
layer_communicator
.
prepare_attn
(
hidden_states
,
residual
,
forward_batch
)
if
hidden_states
.
shape
[
0
]
!=
0
:
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
forward_batch
=
forward_batch
,
)
hidden_states
,
residual
=
self
.
layer_communicator
.
prepare_mlp
(
hidden_states
,
residual
,
forward_batch
)
hidden_states
=
self
.
mlp
(
hidden_states
,
forward_batch
)
hidden_states
,
residual
=
self
.
layer_communicator
.
postprocess_layer
(
hidden_states
,
residual
,
forward_batch
)
return
hidden_states
,
residual
class
GptOssModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
decoder_layer_type
:
type
[
nn
.
Module
]
=
GptOssDecoderLayer
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
pp_group
=
get_pp_group
()
if
self
.
pp_group
.
is_first_rank
:
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
enable_tp
=
not
global_server_args_dict
[
"enable_dp_attention"
],
prefix
=
add_prefix
(
"embed_tokens"
,
prefix
),
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
# Use the provided decoder layer type or default to GptOssDecoderLayer
decoder_layer_type
=
decoder_layer_type
or
GptOssDecoderLayer
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
=
make_layers
(
config
.
num_hidden_layers
,
lambda
idx
,
prefix
:
decoder_layer_type
(
layer_id
=
idx
,
config
=
config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
),
pp_rank
=
self
.
pp_group
.
rank_in_group
,
pp_size
=
self
.
pp_group
.
world_size
,
prefix
=
add_prefix
(
"layers"
,
prefix
),
)
if
self
.
pp_group
.
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
(
return_tuple
=
True
)
self
.
layers_to_capture
=
[]
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                if i in self.layers_to_capture:
                    aux_hidden_states.append(hidden_states + residual)
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual
                )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states
class GptOssForCausalLM(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = GptOssModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = False
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids,
                hidden_states,
                self.lm_head,
                forward_batch,
                aux_hidden_states,
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer
    def _get_default_weight_mapping(self):
        """Generate default weight name mapping for GptOss safetensors."""
        weight_mapping = {}

        # Map router weights to gate
        weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
        weight_mapping["unembedding.weight"] = "lm_head.weight"
        weight_mapping["norm.scale"] = "model.norm.weight"

        for layer_id in range(self.config.num_hidden_layers):
            weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.q_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.k_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
                f"model.layers.{layer_id}.self_attn.v_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.weight"
            )
            weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
                f"model.layers.{layer_id}.self_attn.o_proj.bias"
            )
            weight_mapping[f"block.{layer_id}.attn.sinks"] = (
                f"model.layers.{layer_id}.self_attn.sinks"
            )
            weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
                f"model.layers.{layer_id}.input_layernorm.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
                f"model.layers.{layer_id}.mlp.router.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
                f"model.layers.{layer_id}.mlp.router.bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
                f"model.layers.{layer_id}.post_attention_layernorm.weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
            )
            weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
            )
            weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
                f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
            )

        return weight_mapping
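A quick way to sanity-check this mapping is to look at a few of the layer-0 renames it generates. The pairs below are read directly off the f-strings above (checkpoint name on the left, sglang parameter name on the right); the snippet is only an illustrative sketch and simply prints them.

example_pairs = {
    "block.0.attn.q_proj.weight": "model.layers.0.self_attn.q_proj.weight",
    "block.0.attn.sinks": "model.layers.0.self_attn.sinks",
    "block.0.mlp.gate.weight": "model.layers.0.mlp.router.weight",
    "block.0.mlp.down_proj": "model.layers.0.mlp.experts.mlp2_weight",
}
for checkpoint_name, sglang_name in example_pairs.items():
    # Each checkpoint tensor is renamed before the regular loaders run.
    print(f"{checkpoint_name} -> {sglang_name}")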
    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        is_nextn: bool = False,
        weight_name_mapping: dict = None,
    ):
        tp_rank = get_tensor_model_parallel_rank()

        if is_nextn:
            logging.warning(
                "Loading weights for nextn is currently not supported in GptOssForCausalLM. "
            )
            return

        weights = _canonicalize_weights(self.config, weights)
        weights = sorted(weights, key=lambda x: x[0])  # Sort by name for consistency

        new_weights = []
        for name, p in weights:
            if "qkv.weight" in name:
                q_proj, k_proj, v_proj = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
                )
                new_weights.append(
                    (f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
                )
            elif "qkv.bias" in name:
                q_bias, k_bias, v_bias = p.split(
                    [
                        self.config.num_attention_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                        self.config.num_key_value_heads * self.config.head_dim,
                    ],
                    dim=0,
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
                )
                new_weights.append(
                    (f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
                )
            else:
                new_weights.append((name, p))
        weights = new_weights

        # Use provided weight name mapping if available, otherwise use default
        if weight_name_mapping is None:
            weight_name_mapping = self._get_default_weight_mapping()
        else:
            # Merge with default mapping
            default_mapping = self._get_default_weight_mapping()
            default_mapping.update(weight_name_mapping)
            weight_name_mapping = default_mapping

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
            ckpt_gate_up_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
            ckpt_down_proj_bias_name="down_proj_bias",
        )

        params_dict = dict(self.named_parameters())
        params_checker = {k: False for k, v in params_dict.items()}

        for name, loaded_weight in weights:
            loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)

            # Apply weight name mapping if provided
            if weight_name_mapping and name in weight_name_mapping:
                name = weight_name_mapping[name]

            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                params_checker[name] = True
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if "bias" not in name:
                        loaded_weight = loaded_weight.transpose(-2, -1)
                    if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
                        loaded_weight = loaded_weight.zero_()
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                    )
                    params_checker[name] = True
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue
                    if name in params_dict.keys():
                        param = params_dict[name]
                        if "sinks" in name:
                            start = tp_rank * param.numel()
                            param.data.copy_(
                                loaded_weight[start : start + param.numel()]
                            )
                        else:
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, loaded_weight)
                        params_checker[name] = True
                    else:
                        logger.warning(f"Parameter {name} not found in params_dict")

        not_loaded_params = [k for k, v in params_checker.items() if not v]
        if tp_rank == 0:
            if len(not_loaded_params) > 0:
                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
            else:
                logging.info("All parameters loaded successfully.")

        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
        }
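The fused-QKV handling above is just a split along dim 0 sized by head counts. The sketch below reproduces that arithmetic with hypothetical numbers (64 query heads, 8 KV heads, head_dim 64, hidden_size 2880; placeholders, not necessarily the real gpt-oss configuration) so the slice sizes are easy to see.

import torch

# Hypothetical head configuration used only for illustration.
num_attention_heads, num_key_value_heads, head_dim, hidden_size = 64, 8, 64, 2880

# A stand-in fused QKV weight with Q, K, V stacked along dim 0.
qkv = torch.randn(
    (num_attention_heads + 2 * num_key_value_heads) * head_dim, hidden_size
)
q, k, v = qkv.split(
    [
        num_attention_heads * head_dim,   # 4096 rows for Q
        num_key_value_heads * head_dim,   # 512 rows for K
        num_key_value_heads * head_dim,   # 512 rows for V
    ],
    dim=0,
)
print(q.shape, k.shape, v.shape)  # [4096, 2880], [512, 2880], [512, 2880]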
    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        if not self.pp_group.is_last_rank:
            return

        if layer_ids is None:
            self.capture_aux_hidden_states = True
            num_layers = self.config.num_hidden_layers
            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
        else:
            self.capture_aux_hidden_states = True
            # Add 1 because in sglang the i-th layer captures the output of the
            # (i-1)-th layer as its aux hidden state.
            self.model.layers_to_capture = [val + 1 for val in layer_ids]
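As a worked example of the default capture choice above, a hypothetical 36-layer config would capture layers [2, 18, 33]; the snippet just evaluates the same expression with that placeholder layer count.

# Hypothetical layer count; the real value comes from config.num_hidden_layers.
num_layers = 36
layers_to_capture = [2, num_layers // 2, num_layers - 3]
print(layers_to_capture)  # [2, 18, 33]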
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_local_experts,
            num_groups=None,
        )

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)
def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
    weights_out_dict = dict(weights_in)

    for layer_id in range(config.num_hidden_layers):
        for name_chunk in ["mlp1_weight", "mlp2_weight"]:
            name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
            w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
            w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
            if w_blocks is not None:
                weights_out_dict[name_prefix] = _WeightCreator(
                    partial(
                        _dequant_mlp_weight,
                        debug_name=name_prefix,
                        w_blocks=w_blocks,
                        w_scales=w_scales,
                    )
                )

    return list(weights_out_dict.items())
def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
    if get_tensor_model_parallel_rank() == 0:
        logger.info(f"Dequantize {debug_name} start")

    original_device = w_blocks.device
    w_blocks = w_blocks.cuda()
    w_scales = w_scales.cuda()

    w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
    w_bf16 = w_bf16.transpose(-2, -1).contiguous()

    if get_tensor_model_parallel_rank() == 0:
        logger.info(
            f"Dequantize {debug_name} end "
            f"{w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
        )

    return w_bf16.to(original_device)
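dequant_mxfp4 is imported from sglang's quantization utilities elsewhere in this commit. Assuming the standard MXFP4 microscaling layout (32 four-bit values per shared scale, packed two per byte), the rough relationship between the packed blocks, the scales, and the dequantized bf16 tensor looks like the sketch below; the exact tensor layout the kernel expects may differ.

# Back-of-the-envelope shape check under the assumed MXFP4 layout:
# two FP4 values per uint8 and one shared scale per 32-value block.
rows, cols = 2880, 2880              # hypothetical bf16 weight shape after dequant
packed_bytes_per_row = cols // 2     # 1440 uint8 bytes of packed FP4 values
scales_per_row = cols // 32          # 90 block scales
print(packed_bytes_per_row, scales_per_row)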
class _WeightCreator:
    def __init__(self, fn):
        self._fn = fn

    @staticmethod
    def maybe_materialize(obj):
        if isinstance(obj, _WeightCreator):
            output = obj._fn()
            obj._fn = None
            return output
        return obj
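Together, _canonicalize_weights and _WeightCreator defer the GPU-heavy dequantization until load_weights actually consumes the tensor. A minimal sketch of the same pattern, with a hypothetical stand-in materializer in place of _dequant_mlp_weight:

from functools import partial

def _expensive_materialize(name):
    # Stand-in for _dequant_mlp_weight; runs only when the weight is consumed.
    print(f"materializing {name}")
    return name.upper()

lazy = _WeightCreator(partial(_expensive_materialize, "block.0.mlp.mlp2_weight"))
# Nothing has run yet; the loader triggers materialization later.
value = _WeightCreator.maybe_materialize(lazy)        # prints "materializing ..."
passthrough = _WeightCreator.maybe_materialize(value) # non-creator objects pass through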
EntryClass = GptOssForCausalLM
python/sglang/srt/server_args.py
View file @
c1d2061f
...
...
@@ -457,6 +457,10 @@ class ServerArgs:
            raise ValueError(
                "trtllm_mla backend does not support speculative decoding yet."
            )

        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            self.attention_backend = "triton"
            self.enable_triton_kernel_moe = True

        # Set page size
        if self.page_size is None:
...
...