Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4d4061b6
Unverified
Commit
4d4061b6
authored
Aug 17, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 17, 2025
Browse files
[Kernel] Add cuda kernel for gpt_oss activation (#22951)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
87f48623
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
157 additions
and
42 deletions
+157
-42
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+59
-0
csrc/ops.h
csrc/ops.h
+2
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-0
tests/kernels/core/test_activation.py
tests/kernels/core/test_activation.py
+39
-6
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+38
-3
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+5
-17
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-13
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+2
-2
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+1
-1
No files found.
csrc/activation_kernels.cu
View file @
4d4061b6
...
@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
...
@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
}
}
}
}
template
<
typename
T
>
__device__
__forceinline__
T
swigluoai_and_mul
(
const
T
&
gate
,
const
T
&
up
,
float
alpha
,
float
limit
)
{
// clamp gate: min=None, max=limit
const
float
gate_f
=
(
float
)
gate
;
const
float
clamped_gate
=
gate_f
>
limit
?
limit
:
gate_f
;
// clamp up: min=-limit, max=limit
const
float
up_f
=
(
float
)
up
;
const
float
clamped_up
=
up_f
>
limit
?
limit
:
(
up_f
<
-
limit
?
-
limit
:
up_f
);
// glu = gate * sigmoid(gate * alpha)
const
float
sigmoid_val
=
1.0
f
/
(
1.0
f
+
expf
(
-
clamped_gate
*
alpha
));
const
float
glu
=
clamped_gate
*
sigmoid_val
;
// (up + 1) * glu
return
(
T
)((
clamped_up
+
1.0
f
)
*
glu
);
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
,
const
scalar_t
&
,
const
float
,
const
float
)>
__global__
void
swigluoai_and_mul_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
,
const
float
alpha
,
const
float
limit
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
// TODO: Vectorize loads and stores.
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
// gate = x[..., ::2] (even indices)
const
scalar_t
gate
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
]);
// up = x[..., 1::2] (odd indices)
const
scalar_t
up
=
VLLM_LDG
(
&
input
[
token_idx
*
2
*
d
+
2
*
idx
+
1
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
gate
,
up
,
alpha
,
limit
);
}
}
}
// namespace vllm
}
// namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
...
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
...
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
PARAM); \
PARAM); \
});
});
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, ALPHA, \
LIMIT); \
});
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d],
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
// [..., d],
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
input
,
// [..., 2 * d]
double
threshold
)
{
double
threshold
)
{
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM
(
vllm
::
fatrelu_kernel
,
threshold
);
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM
(
vllm
::
fatrelu_kernel
,
threshold
);
}
}
void
swigluoai_and_mul
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., 2 * d]
double
alpha
,
double
limit
)
{
LAUNCH_SIGLUOAI_AND_MUL
(
vllm
::
swigluoai_and_mul
,
alpha
,
limit
);
}
namespace
vllm
{
namespace
vllm
{
// Element-wise activation kernel template.
// Element-wise activation kernel template.
...
...
csrc/ops.h
View file @
4d4061b6
...
@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
...
@@ -138,6 +138,8 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
fatrelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
double
threshold
);
double
threshold
);
void
swigluoai_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
double
alpha
=
1.702
,
double
limit
=
7.0
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/torch_bindings.cpp
View file @
4d4061b6
...
@@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -130,6 +130,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"
);
ops
.
def
(
"fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"
);
ops
.
impl
(
"fatrelu_and_mul"
,
torch
::
kCUDA
,
&
fatrelu_and_mul
);
ops
.
impl
(
"fatrelu_and_mul"
,
torch
::
kCUDA
,
&
fatrelu_and_mul
);
ops
.
def
(
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
"limit=7.0) "
"-> ()"
);
ops
.
impl
(
"swigluoai_and_mul"
,
torch
::
kCUDA
,
&
swigluoai_and_mul
);
// GELU implementation used in GPT-2.
// GELU implementation used in GPT-2.
ops
.
def
(
"gelu_new(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"gelu_new(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_new"
,
torch
::
kCUDA
,
&
gelu_new
);
ops
.
impl
(
"gelu_new"
,
torch
::
kCUDA
,
&
gelu_new
);
...
...
tests/kernels/core/test_activation.py
View file @
4d4061b6
...
@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
...
@@ -11,7 +11,7 @@ from tests.kernels.utils import opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
FatreluAndMul
,
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
FatreluAndMul
,
GeluAndMul
,
MulAndSilu
,
GeluAndMul
,
MulAndSilu
,
NewGELU
,
QuickGELU
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
SiluAndMul
,
SwigluOAIAndMul
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -25,7 +25,15 @@ CUDA_DEVICES = [
...
@@ -25,7 +25,15 @@ CUDA_DEVICES = [
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"activation"
,
"activation"
,
[
"silu_and_mul"
,
"mul_and_silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
])
[
"silu_and_mul"
,
"mul_and_silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
,
"swigluoai_and_mul"
,
],
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
@@ -59,18 +67,43 @@ def test_act_and_mul(
...
@@ -59,18 +67,43 @@ def test_act_and_mul(
threshold
=
random
.
uniform
(
0
,
1
)
threshold
=
random
.
uniform
(
0
,
1
)
layer
=
FatreluAndMul
(
threshold
)
layer
=
FatreluAndMul
(
threshold
)
fn
=
torch
.
ops
.
_C
.
fatrelu_and_mul
fn
=
torch
.
ops
.
_C
.
fatrelu_and_mul
elif
activation
==
"swigluoai_and_mul"
:
layer
=
SwigluOAIAndMul
()
fn
=
torch
.
ops
.
_C
.
swigluoai_and_mul
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
if
activation
==
"swigluoai_and_mul"
:
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
rtol
=
{
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
#For fp16, change the relative tolerance from 1e-3 to 2e-3
torch
.
float16
:
2e-3
,
torch
.
bfloat16
:
2e-2
,
torch
.
float
:
1.3e-6
}
def
_get_rtol
(
output
)
->
float
:
return
rtol
[
output
.
dtype
]
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
get_default_atol
(
out
),
rtol
=
_get_rtol
(
out
))
else
:
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
d
=
x
.
shape
[
-
1
]
//
2
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
if
activation
==
"fatrelu"
:
if
activation
==
"fatrelu"
:
opcheck
(
fn
,
(
out
,
x
,
threshold
))
opcheck
(
fn
,
(
out
,
x
,
threshold
))
elif
activation
==
"swigluoai_and_mul"
:
opcheck
(
fn
,
(
out
,
x
,
layer
.
alpha
,
layer
.
limit
))
else
:
else
:
opcheck
(
fn
,
(
out
,
x
))
opcheck
(
fn
,
(
out
,
x
))
...
...
vllm/model_executor/layers/activation.py
View file @
4d4061b6
...
@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
...
@@ -239,6 +239,35 @@ class GeluAndMul(CustomOp):
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
@
CustomOp
.
register
(
"swigluoai_and_mul"
)
class
SwigluOAIAndMul
(
CustomOp
):
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
def
__init__
(
self
,
alpha
:
float
=
1.702
,
limit
:
float
=
7.0
):
super
().
__init__
()
self
.
alpha
=
alpha
self
.
limit
=
limit
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
gate
,
up
=
x
[...,
::
2
],
x
[...,
1
::
2
]
gate
=
gate
.
clamp
(
min
=
None
,
max
=
self
.
limit
)
up
=
up
.
clamp
(
min
=-
self
.
limit
,
max
=
self
.
limit
)
glu
=
gate
*
torch
.
sigmoid
(
gate
*
self
.
alpha
)
gated_output
=
(
up
+
1
)
*
glu
return
gated_output
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
torch
.
ops
.
_C
.
swigluoai_and_mul
(
out
,
x
,
self
.
alpha
,
self
.
limit
)
return
out
def
extra_repr
(
self
)
->
str
:
return
f
"alpha=
{
repr
(
self
.
alpha
)
}
, limit=
{
repr
(
self
.
limit
)
}
"
@
CustomOp
.
register
(
"gelu_new"
)
@
CustomOp
.
register
(
"gelu_new"
)
class
NewGELU
(
CustomOp
):
class
NewGELU
(
CustomOp
):
...
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
...
@@ -330,6 +359,7 @@ class ReLUSquaredActivation(CustomOp):
return
torch
.
square
(
F
.
relu
(
x
))
return
torch
.
square
(
F
.
relu
(
x
))
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
#TODO : implement cuda kenrels
return
self
.
forward_native
(
x
)
return
self
.
forward_native
(
x
)
...
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
...
@@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
_ACTIVATION_AND_MUL_REGISTRY
=
LazyDict
({
_ACTIVATION_AND_MUL_REGISTRY
=
LazyDict
({
"gelu"
:
lambda
:
GeluAndMul
(),
"gelu"
:
"silu"
:
lambda
:
SiluAndMul
(),
lambda
:
GeluAndMul
(),
"geglu"
:
lambda
:
GeluAndMul
(),
"silu"
:
lambda
:
SiluAndMul
(),
"geglu"
:
lambda
:
GeluAndMul
(),
"swigluoai"
:
lambda
*
args
,
**
kwargs
:
SwigluOAIAndMul
(
*
args
,
**
kwargs
),
})
})
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
4d4061b6
...
@@ -161,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
...
@@ -161,25 +161,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
if
activation
==
"silu"
:
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
elif
activation
==
"swiglu_oai"
:
elif
activation
==
"swigluoai"
:
# NOTE: in gpt-oss, the gate_proj and up_proj is interleaved
# alpha = 1.702, limit = 7.0
# - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2]
torch
.
ops
.
_C
.
swigluoai_and_mul
(
intermediate_cache2
,
# - origin: gate, up = gate_up[..., :N], gate_up[..., N:]
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
@
torch
.
compile
(
dynamic
=
True
)
def
swiglu_oai
(
gate_up
):
alpha
=
1.702
limit
=
7.0
gate
,
up
=
gate_up
[...,
::
2
],
gate_up
[...,
1
::
2
]
gate
=
gate
.
clamp
(
min
=
None
,
max
=
limit
)
up
=
up
.
clamp
(
min
=-
limit
,
max
=
limit
)
glu
=
gate
*
torch
.
sigmoid
(
gate
*
alpha
)
return
(
up
+
1
)
*
glu
intermediate_cache2
=
swiglu_oai
(
intermediate_cache1
)
else
:
else
:
raise
ValueError
(
f
"Unsupported activation:
{
activation
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
activation
}
. "
"Only silu and swiglu
_
oai activations are supported."
)
"Only silu and swigluoai activations are supported."
)
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
intermediate_cache3
.
zero_
()
intermediate_cache3
.
zero_
()
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
4d4061b6
...
@@ -1621,17 +1621,6 @@ def fused_experts_impl(
...
@@ -1621,17 +1621,6 @@ def fused_experts_impl(
block_shape
=
block_shape
,
block_shape
=
block_shape
,
B_bias
=
w1_bias
)
B_bias
=
w1_bias
)
# TODO fused kernel
def
swiglu_oai
(
gate_up
):
alpha
=
1.702
limit
=
7.0
gate
,
up
=
gate_up
[...,
::
2
],
gate_up
[...,
1
::
2
]
gate
=
gate
.
clamp
(
min
=
None
,
max
=
limit
)
up
=
up
.
clamp
(
min
=-
limit
,
max
=
limit
)
glu
=
gate
*
torch
.
sigmoid
(
gate
*
alpha
)
gated_output
=
(
up
+
1
)
*
glu
return
gated_output
# Activation function with multiplication
# Activation function with multiplication
if
activation
==
"silu"
and
is_act_and_mul
:
if
activation
==
"silu"
and
is_act_and_mul
:
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
...
@@ -1639,13 +1628,16 @@ def fused_experts_impl(
...
@@ -1639,13 +1628,16 @@ def fused_experts_impl(
elif
activation
==
"gelu"
and
is_act_and_mul
:
elif
activation
==
"gelu"
and
is_act_and_mul
:
torch
.
ops
.
_C
.
gelu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"swigluoai"
and
is_act_and_mul
:
# alpha = 1.702, limit = 7.0
torch
.
ops
.
_C
.
swigluoai_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
# Activation function without multiplication
# Activation function without multiplication
elif
activation
==
"silu"
:
elif
activation
==
"silu"
:
intermediate_cache2
=
F
.
silu
(
intermediate_cache1
.
view
(
-
1
,
N
))
intermediate_cache2
=
F
.
silu
(
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
elif
activation
==
"gelu"
:
intermediate_cache2
=
F
.
gelu
(
intermediate_cache1
.
view
(
-
1
,
N
))
intermediate_cache2
=
F
.
gelu
(
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"swiglu_oai"
:
intermediate_cache2
=
swiglu_oai
(
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
, "
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
, "
f
"with is_act_and_mul=
{
is_act_and_mul
}
."
)
f
"with is_act_and_mul=
{
is_act_and_mul
}
."
)
...
...
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
View file @
4d4061b6
...
@@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
...
@@ -61,14 +61,14 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
scoring_func
:
str
=
"softmax"
,
scoring_func
:
str
=
"softmax"
,
activation
:
str
=
"swiglu
_
oai"
,
activation
:
str
=
"swigluoai"
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
):
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
):
return
not
(
use_grouped_topk
or
topk_group
or
num_expert_group
return
not
(
use_grouped_topk
or
topk_group
or
num_expert_group
or
expert_map
or
custom_routing_function
or
expert_map
or
custom_routing_function
or
e_score_correction_bias
or
apply_router_weight_on_input
or
e_score_correction_bias
or
apply_router_weight_on_input
or
scoring_func
!=
"softmax"
or
activation
!=
"swiglu
_
oai"
or
scoring_func
!=
"softmax"
or
activation
!=
"swigluoai"
or
expert_load_view
or
logical_to_physical_map
or
expert_load_view
or
logical_to_physical_map
or
logical_replica_count
)
or
logical_replica_count
)
...
...
vllm/model_executor/models/gpt_oss.py
View file @
4d4061b6
...
@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
...
@@ -159,7 +159,7 @@ class MLPBlock(torch.nn.Module):
prefix
=
f
"
{
prefix
}
.experts"
,
prefix
=
f
"
{
prefix
}
.experts"
,
apply_router_weight_on_input
=
False
,
apply_router_weight_on_input
=
False
,
has_bias
=
True
,
has_bias
=
True
,
activation
=
"swiglu
_
oai"
)
activation
=
"swigluoai"
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
t
=
self
.
norm
(
x
)
t
=
self
.
norm
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment