change / sglang / Commits / 46d8fb1c

Unverified commit 46d8fb1c, authored Sep 12, 2025 by EduardDurech, committed by GitHub Sep 11, 2025

model: support Apertus (#9774)

parent c7e85f53
Showing 3 changed files with 801 additions and 0 deletions (+801 -0)

    python/sglang/srt/layers/activation.py        +110 -0
    python/sglang/srt/models/apertus.py           +686 -0
    test/srt/models/test_generation_models.py       +5 -0
python/sglang/srt/layers/activation.py  (view file @ 46d8fb1c)

...
@@ -171,6 +171,115 @@ class QuickGELU(CustomOp):
        return torch_npu.npu_fast_gelu(x)
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
    Otherwise, we emit a single warning and use xIELU Python
    """

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0)
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()"""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).\n"
                "Note: For SGLang this may be expected if sending"
                "[B*S,D] instead of [B,S,D].",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling, using Python version of xIELU."
                )
        return self._xielu_python(input)
class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.
...
@@ -218,6 +327,7 @@ _ACTIVATION_REGISTRY = {
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "gelu_new": NewGELU(),
    "relu2": ReLU2(),
    "xielu": XIELU(),
}
...
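The new registry entry means the activation can be resolved by name, which is what lines up with the `hidden_act == "xielu"` check in the Apertus MLP below. A minimal lookup sketch; the diff only shows the registry literal, so direct dict access is used here for illustration rather than any accessor function in activation.py.

import torch
from sglang.srt.layers.activation import _ACTIVATION_REGISTRY

act = _ACTIVATION_REGISTRY["xielu"]    # an XIELU() instance
y = act(torch.randn(4, 4096))          # falls back to the Python path when the CUDA extension is absent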
python/sglang/srt/models/apertus.py  (new file, 0 → 100644, view file @ 46d8fb1c)
# Copyright 2025 The SwissAI Initiative
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only Apertus model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import ApertusConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import XIELU
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
class ApertusMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
            reduce_results=reduce_results,
        )
        if hidden_act != "xielu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only xIELU is supported for now."
            )
        self.act_fn = XIELU()

    def forward(
        self,
        x,
        forward_batch=None,
        use_reduce_scatter: bool = False,
    ):
        # note: with xielu, there's no gate_proj
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(
            x,
            skip_all_reduce=use_reduce_scatter,
        )
        return x
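Because xIELU carries its own learned nonlinearity, the Apertus MLP is a plain up_proj → activation → down_proj stack rather than the gated (gate_proj/up_proj) SiLU MLP of Llama-style models. A rough shape-level sketch with plain nn.Linear layers; the sizes are illustrative only, and SiLU stands in for XIELU so the snippet runs without sglang. The parallel linear layers in the real class also return an optional bias term, hence the `x, _ = ...` unpacking above.

import torch
from torch import nn

hidden_size, intermediate_size = 4096, 21504   # made-up sizes for illustration

up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()                                # stand-in for XIELU()

x = torch.randn(2, 8, hidden_size)             # [batch, seq, hidden]
y = down_proj(act(up_proj(x)))                 # no gate_proj / elementwise gate
assert y.shape == x.shape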
class ApertusAttention(nn.Module):
    def __init__(
        self,
        config: ApertusConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
        bias_o_proj: bool = False,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q)
        k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
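The branch on `total_num_kv_heads >= tp_size` above separates partitioning KV heads across tensor-parallel ranks from replicating them. A small worked sketch of that arithmetic; the head counts here are illustrative, not Apertus-8B's actual configuration.

def heads_per_rank(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    # Mirrors the assertions and divisions in ApertusAttention.__init__.
    assert total_num_heads % tp_size == 0
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0   # partition KV heads across ranks
    else:
        assert tp_size % total_num_kv_heads == 0   # replicate KV heads across ranks
    return total_num_heads // tp_size, max(1, total_num_kv_heads // tp_size)

print(heads_per_rank(32, 8, tp_size=4))   # (8, 2): KV heads are partitioned
print(heads_per_rank(32, 2, tp_size=4))   # (8, 1): each KV head is replicated on 2 ranks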
class ApertusDecoderLayer(nn.Module):
    def __init__(
        self,
        config: ApertusConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias
        self.self_attn = ApertusAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
        )
        self.mlp = ApertusMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=add_prefix("mlp", prefix),
        )
        self.attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.feedforward_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_layernorm(hidden_states)
        else:
            hidden_states, residual = self.attention_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        # Fully Connected
        hidden_states, residual = self.feedforward_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
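The two-argument calls to attention_layernorm and feedforward_layernorm use the fused add-and-normalize RMSNorm form seen throughout sglang's Llama-family models: passing the residual returns both the normalized hidden states and the updated residual stream. A rough unfused sketch of one decoder step under that reading; the callables are placeholders, not sglang APIs.

import torch

def decoder_step(hidden_states, residual, rms_norm, self_attn, mlp):
    # Unfused sketch of the residual bookkeeping in ApertusDecoderLayer.forward.
    if residual is None:
        residual = hidden_states
        normed = rms_norm(hidden_states)
    else:
        residual = residual + hidden_states    # the add that the fused call performs
        normed = rms_norm(residual)
    attn_out = self_attn(normed)

    residual = residual + attn_out             # second fused add-and-norm
    normed = rms_norm(residual)
    return mlp(normed), residual

# Tiny smoke test with stand-in callables.
h = torch.randn(2, 8, 16)
out, res = decoder_step(
    h, None, lambda t: t / t.norm(dim=-1, keepdim=True), lambda t: t, lambda t: t
)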
class ApertusModel(nn.Module):
    def __init__(
        self,
        config: ApertusConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.quant_config = quant_config
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: ApertusDecoderLayer(
                config=config,
                quant_config=quant_config,
                layer_id=idx,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="model.layers",
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]
            deferred_norm = None

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            if i in self.layers_to_capture:
                aux_hidden_states.append(hidden_states + residual)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling "
                    "factor attribute!"
                )
class ApertusForCausalLM(nn.Module):
    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        ".q_proj": (".qkv_proj", 0),
        ".k_proj": (".qkv_proj", 1),
        ".v_proj": (".qkv_proj", 2),
    }

    def __init__(
        self,
        config: ApertusConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = self._init_model(
            config, quant_config, add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        self.capture_aux_hidden_states = False

    def _init_model(
        self,
        config: ApertusConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        return ApertusModel(config, quant_config=quant_config, prefix=prefix)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids,
                    hidden_states,
                    self.lm_head,
                    forward_batch,
                    aux_hidden_states,
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ) -> Optional[LogitsProcessorOutput]:
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_module_name_from_weight_name(self, name):
        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, buffer in self.named_buffers():
            if name.endswith(".beta") or name.endswith(".eps"):
                params_dict[name] = buffer

        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            # Handle FP8 kv-scale remapping
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
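load_weights folds separate q_proj/k_proj/v_proj checkpoint tensors into the fused qkv_proj parameter via stacked_params_mapping, and it also registers the XIELU `.beta`/`.eps` buffers in params_dict so the activation's constants can be restored from the checkpoint. A small illustration of the name rewriting only; the checkpoint key below is a hypothetical example, and tensors and loaders are omitted.

# Name-rewriting sketch for stacked_params_mapping.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

checkpoint_name = "model.layers.0.self_attn.k_proj.weight"   # hypothetical checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_name:
        target = checkpoint_name.replace(weight_name, param_name)
        print(target, shard_id)   # model.layers.0.self_attn.qkv_proj.weight k
        break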
    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        return self.model.embed_tokens.weight

    def set_embed(self, embed):
        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
        if (
            hasattr(self.config, "target_hidden_size")
            and self.config.target_hidden_size != self.config.hidden_size
        ):
            return
        del self.model.embed_tokens.weight
        self.model.embed_tokens.weight = embed
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        if not self.pp_group.is_last_rank:
            return

        if layer_ids is None:
            self.capture_aux_hidden_states = True
            num_layers = self.config.num_hidden_layers
            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
        else:
            self.capture_aux_hidden_states = True
            # we plus 1 here because in sglang, for the ith layer, it takes the output
            # of the (i-1)th layer as aux hidden state
            self.model.layers_to_capture = [val + 1 for val in layer_ids]
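The `val + 1` offset reflects that ApertusModel.forward records the aux hidden state at the top of layer i, which is the output of layer i-1. As a quick worked example with arbitrarily chosen indices: layer_ids=[1, 15, 29] sets layers_to_capture to [2, 16, 30], while layer_ids=None on a 32-layer config yields [2, 16, 29].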
EntryClass = [ApertusForCausalLM]
test/srt/models/test_generation_models.py  (view file @ 46d8fb1c)
...
@@ -90,6 +90,11 @@ ALL_MODELS = [
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "swiss-ai/Apertus-8B",
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
]

TORCH_DTYPES = [torch.float16]
...
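With this commit, Apertus checkpoints can be loaded like any other supported model. Below is a minimal offline sketch using sglang's Engine API; the model path mirrors the new test entry, the sampling parameters are arbitrary, and this snippet is illustrative rather than part of the commit.

import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="swiss-ai/Apertus-8B", trust_remote_code=True)
    out = llm.generate("The Swiss Alps are", {"temperature": 0.7, "max_new_tokens": 32})
    print(out["text"])
    llm.shutdown()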