46d8fb1c
Unverified
Commit
46d8fb1c
authored
Sep 12, 2025
by
EduardDurech
Committed by
GitHub
Sep 11, 2025
Browse files
model: support Apertus (#9774)
parent
c7e85f53
Showing 3 changed files with 801 additions and 0 deletions (+801 -0)
python/sglang/srt/layers/activation.py        +110  -0
python/sglang/srt/models/apertus.py           +686  -0
test/srt/models/test_generation_models.py       +5  -0
python/sglang/srt/layers/activation.py (view file @ 46d8fb1c)

@@ -171,6 +171,115 @@ class QuickGELU(CustomOp):
        return torch_npu.npu_fast_gelu(x)


class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0)
        )
        self.alpha_n = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).\n"
                "Note: For SGLang this may be expected if sending"
                "[B*S,D] instead of [B,S,D].",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.
    ...

@@ -218,6 +327,7 @@ _ACTIVATION_REGISTRY = {
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "gelu_new": NewGELU(),
    "relu2": ReLU2(),
    "xielu": XIELU(),
}
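Not part of the diff: for reference, the Python fallback above is a piecewise function, a quadratic-plus-linear branch for x > 0 and a clamped expm1 branch otherwise, with alpha_p and alpha_n stored pre-softplus so the effective coefficients stay positive. The sketch below is our own standalone helper (the name xielu_reference is hypothetical), using only plain PyTorch and the constructor defaults from the class above; it mirrors that math and shows the two branches joining near zero.

import torch
import torch.nn.functional as F

def xielu_reference(x, alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5, eps=-1e-6):
    # alpha_p / alpha_n are stored pre-softplus, exactly as in XIELU.__init__;
    # softplus(log(exp(a) - 1)) recovers a, so the defaults round-trip.
    raw_p = torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1)
    raw_n = torch.log(torch.exp(torch.tensor(alpha_n_init - beta)) - 1)
    alpha_p = F.softplus(raw_p)            # ~alpha_p_init
    alpha_n = beta + F.softplus(raw_n)     # ~alpha_n_init
    eps_t = torch.tensor(eps)
    return torch.where(
        x > 0,
        alpha_p * x * x + beta * x,                                    # positive branch
        (torch.expm1(torch.min(x, eps_t)) - x) * alpha_n + beta * x,   # negative branch
    )

x = torch.linspace(-3, 3, 7)
print(xielu_reference(x))
# Both branches are ~0 at x = 0, so the two pieces join continuously.

With the registry entry added in the second hunk, _ACTIVATION_REGISTRY["xielu"] resolves to an XIELU() instance, so a model config can request the activation by name.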
python/sglang/srt/models/apertus.py (new file, 0 → 100644, view file @ 46d8fb1c)

This diff (686 additions) is collapsed in this view.
test/srt/models/test_generation_models.py (view file @ 46d8fb1c)

@@ -90,6 +90,11 @@ ALL_MODELS = [
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "swiss-ai/Apertus-8B",
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
]

TORCH_DTYPES = [torch.float16]
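Not part of the diff: with the model class registered and the test entry above in place, Apertus should be loadable like any other ModelCase. A rough usage sketch with SGLang's offline Engine API, assuming sglang.Engine and its generate/shutdown methods behave as in current SGLang releases; only the checkpoint id comes from the test above, the prompt and sampling parameters are illustrative.

# Sketch only: assumes the offline Engine API (sglang.Engine) is available in
# this SGLang version; not taken from the commit itself.
import sglang as sgl

llm = sgl.Engine(
    model_path="swiss-ai/Apertus-8B",  # checkpoint id used by the new test case
    trust_remote_code=True,
)
out = llm.generate(
    ["The capital of Switzerland is"],
    {"temperature": 0.0, "max_new_tokens": 16},
)
print(out[0]["text"])
llm.shutdown()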