Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c5254b8
Unverified
Commit
0c5254b8
authored
Aug 10, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 09, 2025
Browse files
[oss] Init gpt-oss bf16 support (#22508)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
61f67d8a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
342 additions
and
125 deletions
+342
-125
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+5
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+160
-109
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+28
-12
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+149
-3
No files found.
vllm/model_executor/layers/fused_moe/config.py
View file @
0c5254b8
...
...
@@ -324,6 +324,8 @@ class FusedMoEConfig:
max_num_tokens
:
int
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
has_bias
:
bool
=
False
def
__post_init__
(
self
):
if
self
.
dp_size
>
1
:
logger
.
debug_once
(
"Using FusedMoEConfig::max_num_tokens=%d"
,
...
...
@@ -413,7 +415,8 @@ class FusedMoEConfig:
in_dtype
:
torch
.
dtype
,
max_num_tokens
:
int
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
,
quant_config
:
Optional
[
Union
[
FusedMoEQuantConfig
,
QuantizationConfig
]]
=
None
QuantizationConfig
]]
=
None
,
has_bias
:
bool
=
False
,
)
->
"FusedMoEConfig"
:
_quant_config
:
Optional
[
FusedMoEQuantConfig
]
=
None
...
...
@@ -482,4 +485,5 @@ class FusedMoEConfig:
in_dtype
=
in_dtype
,
quant_config
=
_quant_config
,
max_num_tokens
=
max_num_tokens
,
has_bias
=
has_bias
,
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
0c5254b8
...
...
@@ -275,6 +275,7 @@ def fused_moe_kernel(
a_ptr
,
b_ptr
,
c_ptr
,
b_bias_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
...
...
@@ -302,6 +303,8 @@ def fused_moe_kernel(
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_bbe
,
# bias expert stride
stride_bbn
,
# bias N stride
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
...
...
@@ -317,6 +320,7 @@ def fused_moe_kernel(
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
...
...
@@ -414,7 +418,10 @@ def fused_moe_kernel(
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
if
HAS_BIAS
:
# bias shape: [num_experts, N]
bias_ptrs
=
b_bias_ptr
+
off_experts
*
stride_bbe
+
offs_bn
*
stride_bbn
bias
=
tl
.
load
(
bias_ptrs
,
mask
=
(
offs_bn
<
N
),
other
=
0.0
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
...
...
@@ -456,7 +463,8 @@ def fused_moe_kernel(
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
HAS_BIAS
:
accumulator
=
accumulator
+
bias
[
None
,
:]
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
...
...
@@ -471,6 +479,7 @@ def fused_moe_kernel(
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
...
@@ -499,7 +508,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
None
:
block_shape
:
Optional
[
list
[
int
]]
=
None
,
B_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
...
...
@@ -531,7 +541,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A
.
size
(
0
)
*
top_k
*
config
[
'BLOCK_SIZE_M'
])
grid
=
lambda
META
:
(
triton
.
cdiv
(
EM
,
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
size
(
1
),
META
[
'BLOCK_SIZE_N'
]),
)
HAS_BIAS
=
B_bias
is
not
None
if
(
use_int8_w8a16
or
use_int4_w4a16
)
and
\
block_shape
is
not
None
and
block_shape
[
1
]
>
0
:
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
...
...
@@ -611,6 +621,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
A
,
B
,
C
,
B_bias
,
A_scale
,
B_scale
,
topk_weights
,
...
...
@@ -638,6 +649,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_bias
.
stride
(
0
)
if
B_bias
is
not
None
else
0
,
B_bias
.
stride
(
1
)
if
B_bias
is
not
None
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
...
...
@@ -647,6 +660,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
per_channel_quant
=
per_channel_quant
,
HAS_BIAS
=
HAS_BIAS
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
**
config
,
)
...
...
@@ -1024,40 +1038,43 @@ def inplace_fused_experts(
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
#noqa: UP006
block_shape
:
Optional
[
List
[
int
]]
=
None
,
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
#noqa: UP006
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
is_act_and_mul
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_mxfp4_w4a4
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
is_act_and_mul
:
bool
=
True
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_mxfp4_w4a4
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
None
:
a2_scale
,
block_shape
,
w1_bias
,
w2_bias
)
def inplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    is_act_and_mul: bool = True,
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_mxfp4_w4a4: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """No-op counterpart of ``inplace_fused_experts``.

    Accepts the expert weights, routing tensors, quantization flags and the
    optional ``w1_bias`` / ``w2_bias`` tensors, performs no computation and
    leaves every argument untouched.

    NOTE(review): presumably registered as the fake (shape-only)
    implementation of the in-place custom op -- confirm against the
    registration call, which is outside this view.
    """
    # Nothing to compute: the real op writes into ``hidden_states`` in place,
    # so there is no output tensor to describe here.
    return None
...
...
@@ -1246,36 +1263,38 @@ direct_register_custom_op(
def
outplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
is_act_and_mul
:
bool
=
True
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_mxfp4_w4a4
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
#noqa: UP006
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
=
"silu"
,
is_act_and_mul
:
bool
=
True
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_mxfp4_w4a4
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
#noqa: UP006
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
is_act_and_mul
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_mxfp4_w4a4
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
w1_bias
,
w2_bias
)
def
outplace_fused_experts_fake
(
...
...
@@ -1300,7 +1319,9 @@ def outplace_fused_experts_fake(
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
block_shape
:
Optional
[
list
[
int
]]
=
None
,
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -1332,33 +1353,34 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
is_act_and_mul
:
bool
=
True
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_mxfp4_w4a4
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
is_act_and_mul
:
bool
=
True
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_mxfp4_w4a4
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases because they only support
...
...
@@ -1423,7 +1445,10 @@ def fused_experts(
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
,
)
def
fused_experts_impl
(
...
...
@@ -1451,6 +1476,8 @@ def fused_experts_impl(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# Check constraints.
if
use_int4_w4a16
:
...
...
@@ -1591,7 +1618,19 @@ def fused_experts_impl(
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
B_bias
=
w1_bias
)
# TODO fused kernel
def swiglu_oai(gate_up, alpha=1.702, limit=7.0):
    """Apply the clamped SwiGLU variant selected by ``activation="swiglu_oai"``.

    The last dimension of ``gate_up`` is read as interleaved pairs: even
    positions hold the gate values, odd positions hold the linear ("up")
    values.

    Args:
        gate_up: Tensor whose last dimension interleaves gate/up values.
        alpha: Factor applied to the gate inside the sigmoid.
        limit: Clamp bound. The gate is clamped from above only; the up
            values are clamped to ``[-limit, limit]``.

    Returns:
        ``(up + 1) * gate * sigmoid(alpha * gate)``, with a last dimension
        half the size of the input's.
    """
    # Gate is only bounded from above; large negative gates are already
    # squashed towards zero by the sigmoid term.
    gate = gate_up[..., ::2].clamp(max=limit)
    up = gate_up[..., 1::2].clamp(min=-limit, max=limit)
    return (up + 1) * (gate * torch.sigmoid(gate * alpha))
# Activation function with multiplication
if
activation
==
"silu"
and
is_act_and_mul
:
...
...
@@ -1605,6 +1644,8 @@ def fused_experts_impl(
intermediate_cache2
=
F
.
silu
(
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
intermediate_cache2
=
F
.
gelu
(
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"swiglu_oai"
:
intermediate_cache2
=
swiglu_oai
(
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
, "
f
"with is_act_and_mul=
{
is_act_and_mul
}
."
)
...
...
@@ -1635,7 +1676,8 @@ def fused_experts_impl(
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
B_bias
=
w2_bias
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
@@ -1672,6 +1714,8 @@ def fused_moe(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
w1_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
...
...
@@ -1766,7 +1810,9 @@ def fused_moe(
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
block_shape
=
block_shape
,
w1_bias
=
w1_bias
,
w2_bias
=
w2_bias
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
...
...
@@ -1937,7 +1983,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
block_shape
=
self
.
block_shape
,
B_bias
=
None
# TODO support B_bias
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
...
@@ -1948,26 +1996,29 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache2
,
a2_scale
,
self
.
quant_dtype
,
self
.
per_act_token_quant
,
self
.
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
,
B_bias
=
None
# TODO support B_bias
)
ops
.
moe_sum
(
intermediate_cache3
,
output
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
0c5254b8
...
...
@@ -255,7 +255,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self
.
fused_experts
=
fused_experts
# type: ignore
self
.
topk_indices_dtype
=
None
self
.
moe
=
moe
self
.
has_bias
=
self
.
moe
.
has_bias
self
.
rocm_aiter_moe_enabled
=
is_rocm_aiter_moe_enabled
()
if
self
.
rocm_aiter_moe_enabled
:
from
.rocm_aiter_fused_moe
import
rocm_aiter_fused_experts
...
...
@@ -291,7 +291,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
if
self
.
has_bias
:
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
# down_proj (row parallel)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
...
...
@@ -301,6 +308,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
if
self
.
has_bias
:
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
def
_maybe_pad_weight
(
self
,
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Pad the weight tensor. This is an optimization on ROCm platform, which
...
...
@@ -465,6 +479,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w1_bias
=
layer
.
w13_bias
if
self
.
has_bias
else
None
,
w2_bias
=
layer
.
w2_bias
if
self
.
has_bias
else
None
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
...
...
@@ -702,6 +718,7 @@ class FusedMoE(torch.nn.Module):
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
has_bias
:
bool
=
False
,
):
super
().
__init__
()
if
params_dtype
is
None
:
...
...
@@ -793,16 +810,15 @@ class FusedMoE(torch.nn.Module):
# since model_config is not set in the pytest test.
model_dtype
=
params_dtype
moe
=
FusedMoEConfig
.
make
(
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
hidden_dim
=
hidden_size
,
num_local_experts
=
self
.
local_num_experts
,
moe_parallel_config
=
self
.
moe_parallel_config
,
in_dtype
=
model_dtype
,
max_num_tokens
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
,
quant_config
=
quant_config
,
)
moe
=
FusedMoEConfig
.
make
(
num_experts
=
self
.
global_num_experts
,
experts_per_token
=
top_k
,
hidden_dim
=
hidden_size
,
num_local_experts
=
self
.
local_num_experts
,
moe_parallel_config
=
self
.
moe_parallel_config
,
in_dtype
=
model_dtype
,
max_num_tokens
=
envs
.
VLLM_MOE_DP_CHUNK_SIZE
,
quant_config
=
quant_config
,
has_bias
=
has_bias
)
self
.
moe_config
=
moe
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/gpt_oss.py
View file @
0c5254b8
...
...
@@ -160,7 +160,9 @@ class MLPBlock(torch.nn.Module):
renormalize
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
apply_router_weight_on_input
=
False
)
apply_router_weight_on_input
=
False
,
has_bias
=
True
,
activation
=
"swiglu_oai"
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
t
=
self
.
norm
(
x
)
...
...
@@ -262,8 +264,8 @@ class GptOssForCausalLM(nn.Module):
sampling_metadata
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
_
load_weights
_mxfp4
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
rename_mapping
=
{
"self_attn"
:
"attn"
,
"input_layernorm.weight"
:
"attn.norm.weight"
,
...
...
@@ -469,3 +471,147 @@ class GptOssForCausalLM(nn.Module):
loaded_params
.
add
(
renamed_name
)
return
loaded_params
def _load_weights_other(
        self,
        weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors on the non-mxfp4 path.

    Expert projection weights and biases are sliced for the current rank
    (by expert when expert parallelism is enabled, otherwise along the
    intermediate dimension) and copied into ``w13_*`` / ``w2_*`` parameters.
    Attention sinks are narrowed to this rank's heads, q/k/v projections go
    through the fused ``qkv`` parameter's ``weight_loader``, and everything
    else is renamed and loaded with its own or the default loader.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names (as they appear in
        ``self.named_parameters()``) that were populated.

    Raises:
        KeyError: If an expert, sink or q/k/v tensor maps to a parameter
            name that the module does not have.
    """
    rename_mapping = {
        "self_attn": "attn",
        "input_layernorm.weight": "attn.norm.weight",
        "post_attention_layernorm.weight": "mlp.norm.weight",
        "embed_tokens": "embedding",
    }

    def maybe_rename(name: str) -> str:
        # Only the first matching entry is applied.
        for remap_name, new_name in rename_mapping.items():
            if remap_name in name:
                return name.replace(remap_name, new_name)
        return name

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    intermediate_size = self.model_config.intermediate_size

    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    # Attention heads per rank
    heads_per_rank = self.model_config.num_attention_heads // tp_size
    head_start = tp_rank * heads_per_rank

    use_ep = self.vllm_config.parallel_config.enable_expert_parallel
    ep_size = get_ep_group().world_size
    ep_rank = get_ep_group().rank
    num_experts = self.model_config.num_local_experts
    experts_per_rank = num_experts // ep_size
    ep_rank_start = ep_rank * experts_per_rank
    ep_rank_end = (ep_rank + 1) * experts_per_rank

    for name, weight in weights:
        if ".experts.gate_up_proj" in name and "bias" not in name:
            # Handle MLP gate and up projection weights
            new_name = name.replace(".experts.gate_up_proj",
                                    ".experts.w13_weight")

            # Extract gate and up projection parts
            # since the weight is shuffled, we can slice directly
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, :,
                                       2 * tp_rank_start:2 * tp_rank_end]

            # Swap the last two dims to match the parameter's layout.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif ".experts.down_proj" in name and "bias" not in name:
            # Handle MLP down projection weights
            new_name = name.replace(".experts.down_proj",
                                    ".experts.w2_weight")

            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif "gate_up_proj_bias" in name:
            # Handle MLP gate and up projection biases
            new_name = name.replace("gate_up_proj_bias", "w13_bias")

            # Extract gate and up projection bias parts
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end]

            param = params_dict[new_name]

            param.copy_(narrow_weight)
            loaded_params.add(new_name)

        elif "down_proj_bias" in name:
            # Handle MLP down projection bias
            new_name = name.replace("down_proj_bias", "w2_bias")

            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # (only load on rank 0 to avoid duplication)
                # Use a fresh zero tensor rather than zeroing the
                # checkpoint tensor in place, so the caller's data is
                # never mutated.
                weight = torch.zeros_like(weight)
            param = params_dict[new_name]
            param.copy_(weight)
            loaded_params.add(new_name)

        elif "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            name = name.replace("self_attn", "attn")
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)

        elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
            shard_id = ("q" if "q_proj" in name else
                        "k" if "k_proj" in name else "v")
            name = name.replace("self_attn", "attn")
            param_name = name.replace(f"{shard_id}_proj", "qkv")
            param = params_dict[param_name]
            weight_loader = param.weight_loader
            weight_loader(param, weight, loaded_shard_id=shard_id)
            loaded_params.add(param_name)

        else:
            # Handle all other weights with potential renaming
            renamed_name = maybe_rename(name)
            if renamed_name not in params_dict:
                # Checkpoint entries without a matching parameter are
                # skipped silently.
                continue
            param = params_dict[renamed_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
            loaded_params.add(renamed_name)

    return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Dispatch weight loading according to the checkpoint's quantization.

    Reads ``quant_method`` from ``self.model_config.quantization_config``
    and uses the mxfp4 loader when it equals ``"mxfp4"``; every other case
    (no quantization config, a ``None`` config, or a config without a
    ``quant_method`` entry) goes to the generic loader.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter names reported as loaded by the chosen loader.
    """
    quant_config = getattr(self.model_config, "quantization_config", None)
    # A present-but-None config or a missing key must not raise; both mean
    # "not mxfp4" and fall through to the generic loader.
    quant_method = (quant_config.get("quant_method")
                    if quant_config is not None else None)
    if quant_method == "mxfp4":
        return self._load_weights_mxfp4(weights)
    return self._load_weights_other(weights)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment