Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab87f852
Unverified
Commit
ab87f852
authored
Feb 26, 2026
by
Jiangyun Zhu
Committed by
GitHub
Feb 26, 2026
Browse files
[Model] Ring 2.5 (#35102)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
3827c8c5
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1407 additions
and
70 deletions
+1407
-70
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+3
-0
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+30
-5
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-0
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+124
-65
vllm/model_executor/models/bailing_moe_linear.py
vllm/model_executor/models/bailing_moe_linear.py
+1246
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+1
-0
No files found.
docs/models/supported_models.md
View file @
ab87f852
...
@@ -372,6 +372,7 @@ th {
...
@@ -372,6 +372,7 @@ th {
|
`BaiChuanForCausalLM`
| Baichuan2, Baichuan |
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc. | ✅︎ | ✅︎ |
|
`BaiChuanForCausalLM`
| Baichuan2, Baichuan |
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2ForCausalLM`
| Ling |
`inclusionAI/Ling-mini-2.0`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2ForCausalLM`
| Ling |
`inclusionAI/Ling-mini-2.0`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2_5ForCausalLM`
| Ling |
`inclusionAI/Ling-2.5-1T`
,
`inclusionAI/Ring-2.5-1T`
| | ✅︎ |
|
`BambaForCausalLM`
| Bamba |
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
| ✅︎ | ✅︎ |
|
`BambaForCausalLM`
| Bamba |
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
| ✅︎ | ✅︎ |
|
`BloomForCausalLM`
| BLOOM, BLOOMZ, BLOOMChat |
`bigscience/bloom`
,
`bigscience/bloomz`
, etc. | | ✅︎ |
|
`BloomForCausalLM`
| BLOOM, BLOOMZ, BLOOMChat |
`bigscience/bloom`
,
`bigscience/bloomz`
, etc. | | ✅︎ |
|
`ChatGLMModel`
,
`ChatGLMForConditionalGeneration`
| ChatGLM |
`zai-org/chatglm2-6b`
,
`zai-org/chatglm3-6b`
,
`thu-coai/ShieldLM-6B-chatglm3`
, etc. | ✅︎ | ✅︎ |
|
`ChatGLMModel`
,
`ChatGLMForConditionalGeneration`
| ChatGLM |
`zai-org/chatglm2-6b`
,
`zai-org/chatglm3-6b`
,
`thu-coai/ShieldLM-6B-chatglm3`
, etc. | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
ab87f852
...
@@ -206,6 +206,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -206,6 +206,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BailingMoeV2ForCausalLM"
:
_HfExamplesInfo
(
"BailingMoeV2ForCausalLM"
:
_HfExamplesInfo
(
"inclusionAI/Ling-mini-2.0"
,
trust_remote_code
=
True
"inclusionAI/Ling-mini-2.0"
,
trust_remote_code
=
True
),
),
"BailingMoeV2_5ForCausalLM"
:
_HfExamplesInfo
(
"inclusionAI/Ring-2.5-1T"
,
trust_remote_code
=
True
),
"BambaForCausalLM"
:
_HfExamplesInfo
(
"BambaForCausalLM"
:
_HfExamplesInfo
(
"ibm-ai-platform/Bamba-9B-v1"
,
"ibm-ai-platform/Bamba-9B-v1"
,
extras
=
{
"tiny"
:
"hmellor/tiny-random-BambaForCausalLM"
},
extras
=
{
"tiny"
:
"hmellor/tiny-random-BambaForCausalLM"
},
...
...
vllm/model_executor/layers/fla/ops/layernorm_guard.py
View file @
ab87f852
...
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
...
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
):
):
# Map the program id to the starting row of X and Y it should compute.
# Map the program id to the starting row of X and Y it should compute.
row_start
=
tl
.
program_id
(
0
)
*
ROWS_PER_BLOCK
row_start
=
tl
.
program_id
(
0
)
*
ROWS_PER_BLOCK
...
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
...
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
x
*=
z
*
tl
.
sigmoid
(
z
)
x
*=
z
*
tl
.
sigmoid
(
z
)
elif
ACTIVATION
==
"sigmoid"
:
x
*=
tl
.
sigmoid
(
z
)
# Compute mean and variance per row (reduce along axis 1)
# Compute mean and variance per row (reduce along axis 1)
if
not
IS_RMS_NORM
:
if
not
IS_RMS_NORM
:
...
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
...
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if
HAS_Z
and
NORM_BEFORE_GATE
:
if
HAS_Z
and
NORM_BEFORE_GATE
:
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
y
*=
z
*
tl
.
sigmoid
(
z
)
y
*=
z
*
tl
.
sigmoid
(
z
)
elif
ACTIVATION
==
"sigmoid"
:
y
*=
tl
.
sigmoid
(
z
)
# Write output
# Write output
tl
.
store
(
Y_base
,
y
,
mask
=
mask
)
tl
.
store
(
Y_base
,
y
,
mask
=
mask
)
...
@@ -178,6 +185,7 @@ def layer_norm_fwd(
...
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size
:
int
=
None
,
group_size
:
int
=
None
,
norm_before_gate
:
bool
=
True
,
norm_before_gate
:
bool
=
True
,
is_rms_norm
:
bool
=
False
,
is_rms_norm
:
bool
=
False
,
activation
:
str
=
"swish"
,
):
):
M
,
N
=
x
.
shape
M
,
N
=
x
.
shape
if
group_size
is
None
:
if
group_size
is
None
:
...
@@ -232,9 +240,12 @@ def layer_norm_fwd(
...
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps
,
eps
,
BLOCK_N
=
BLOCK_N
,
BLOCK_N
=
BLOCK_N
,
ROWS_PER_BLOCK
=
rows_per_block
,
ROWS_PER_BLOCK
=
rows_per_block
,
HAS_BIAS
=
bias
is
not
None
,
HAS_Z
=
z
is
not
None
,
NORM_BEFORE_GATE
=
norm_before_gate
,
NORM_BEFORE_GATE
=
norm_before_gate
,
IS_RMS_NORM
=
is_rms_norm
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
ACTIVATION
=
activation
,
)
)
return
out
,
mean
,
rstd
return
out
,
mean
,
rstd
...
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
...
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size
=
group_size
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
is_rms_norm
=
is_rms_norm
,
activation
=
activation
,
)
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
x_shape_og
=
x_shape_og
...
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx
.
group_size
=
group_size
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
activation
=
activation
return
y
.
reshape
(
x_shape_og
)
return
y
.
reshape
(
x_shape_og
)
...
@@ -296,17 +310,25 @@ def layernorm_fn(
...
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
):
return
LayerNormFn
.
apply
(
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
,
activation
)
)
def
rmsnorm_fn
(
def
rmsnorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
activation
:
str
=
"swish"
,
):
):
return
LayerNormFn
.
apply
(
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
,
activation
)
)
...
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
...
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate
:
bool
=
False
,
norm_before_gate
:
bool
=
False
,
device
:
torch
.
device
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
dtype
:
torch
.
dtype
|
None
=
None
,
dtype
:
torch
.
dtype
|
None
=
None
,
activation
:
str
=
"swish"
,
):
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
...
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
...
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
super
().
__init__
()
self
.
eps
=
eps
self
.
eps
=
eps
self
.
activation
=
activation
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
self
.
group_size
=
group_size
self
.
group_size
=
group_size
...
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
...
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps
=
self
.
eps
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
,
norm_before_gate
=
self
.
norm_before_gate
,
activation
=
self
.
activation
,
)
)
vllm/model_executor/layers/layernorm.py
View file @
ab87f852
...
@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
...
@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
eps
=
self
.
eps
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
,
norm_before_gate
=
self
.
norm_before_gate
,
activation
=
self
.
activation
,
)
)
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
ab87f852
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Callable
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
self
.
weight
.
weight_loader
=
self
.
weight_loader
self
.
weight
.
weight_loader
=
self
.
weight_loader
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
return
@
staticmethod
@
staticmethod
def
weight_loader
(
def
weight_loader
(
...
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
shard_size
=
loaded_weight
.
shape
[
0
]
//
tp_world
shard_size
=
loaded_weight
.
shape
[
0
]
//
tp_world
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
param
.
data
.
copy_
(
loaded_weight
[
shard
])
param
.
data
.
copy_
(
loaded_weight
[
shard
])
return
def
_forward
(
def
_forward
(
self
,
self
,
...
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
return
q
,
k
return
q
,
k
def
clear_linear_attention_cache_for_new_sequences
(
kv_cache
:
torch
.
Tensor
,
state_indices_tensor
:
torch
.
Tensor
,
attn_metadata
:
LinearAttentionMetadata
,
)
->
None
:
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
if
num_prefills
<=
0
:
return
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
(
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
)
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
def
linear_attention_decode
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slope_rate
:
torch
.
Tensor
,
state_indices_tensor
:
torch
.
Tensor
,
q_start
:
int
=
0
,
q_end
:
int
|
None
=
None
,
slot_start
:
int
=
0
,
slot_end
:
int
|
None
=
None
,
block_size
:
int
=
32
,
)
->
torch
.
Tensor
:
q
=
q
[
q_start
:
q_end
].
unsqueeze
(
2
).
contiguous
()
k
=
k
[
q_start
:
q_end
].
unsqueeze
(
2
).
contiguous
()
v
=
v
[
q_start
:
q_end
].
unsqueeze
(
2
).
contiguous
()
slot_id
=
state_indices_tensor
[
slot_start
:
slot_end
]
return
linear_decode_forward_triton
(
q
,
k
,
v
,
kv_cache
,
slope_rate
,
slot_id
,
block_size
)
def
linear_attention_prefill_and_mix
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
state_indices_tensor
:
torch
.
Tensor
,
attn_metadata
:
LinearAttentionMetadata
,
slope_rate
:
torch
.
Tensor
,
block_size
:
int
,
decode_fn
:
Callable
[...,
torch
.
Tensor
],
prefix_fn
:
Callable
[...,
torch
.
Tensor
],
layer_idx
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
hidden
=
[]
for
_prefill_idx
in
range
(
getattr
(
attn_metadata
,
"num_prefills"
,
0
)):
if
_prefill_idx
>=
len
(
attn_metadata
.
query_start_loc
):
break
if
_prefill_idx
>=
len
(
state_indices_tensor
):
break
offset
=
attn_metadata
.
num_decode_tokens
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
qs
=
q
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
ks
=
k
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
vs
=
v
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
slice_layer_cache
=
kv_cache
[
slot_id
,
...]
out_slice
=
prefix_fn
(
qs
,
ks
,
vs
,
slice_layer_cache
,
slope_rate
,
block_size
,
layer_idx
=
layer_idx
,
)
hidden
.
append
(
out_slice
.
contiguous
())
if
attn_metadata
.
num_decode_tokens
>
0
:
hidden_decode
=
decode_fn
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
hidden
.
insert
(
0
,
hidden_decode
)
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
hidden
=
torch
.
concat
(
hidden
,
dim
=
0
).
contiguous
()
return
hidden
class
MiniMaxText01LinearKernel
:
class
MiniMaxText01LinearKernel
:
@
staticmethod
@
staticmethod
def
jit_linear_forward_prefix
(
def
jit_linear_forward_prefix
(
...
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
_prefill_and_mix_infer
(
def
_prefill_and_mix_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
):
hidden
=
[]
return
linear_attention_prefill_and_mix
(
for
_prefill_idx
in
range
(
getattr
(
attn_metadata
,
"num_prefills"
,
0
)):
q
=
q
,
if
_prefill_idx
>=
len
(
attn_metadata
.
query_start_loc
):
k
=
k
,
break
v
=
v
,
if
_prefill_idx
>=
len
(
state_indices_tensor
):
kv_cache
=
kv_cache
,
break
state_indices_tensor
=
state_indices_tensor
,
offset
=
attn_metadata
.
num_decode_tokens
attn_metadata
=
attn_metadata
,
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
slope_rate
=
self
.
tp_slope
,
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
block_size
=
self
.
BLOCK
,
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
decode_fn
=
self
.
_decode_infer
,
qs
=
q
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
prefix_fn
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
,
ks
=
k
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
vs
=
v
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
slice_layer_cache
=
kv_cache
[
slot_id
,
...]
out_slice
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
(
qs
,
ks
,
vs
,
slice_layer_cache
,
self
.
tp_slope
,
self
.
BLOCK
,
layer_idx
=
self
.
layer_idx
,
layer_idx
=
self
.
layer_idx
,
)
)
hidden
.
append
(
out_slice
.
contiguous
())
if
attn_metadata
.
num_decode_tokens
>
0
:
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
hidden
.
insert
(
0
,
hidden_decode
)
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
hidden
=
torch
.
concat
(
hidden
,
dim
=
0
).
contiguous
()
return
hidden
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
hidden
=
linear_attention_decode
(
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
q
,
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
,
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
v
,
hidden
=
linear_decode_forward_triton
(
kv_cache
,
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
slot_id
,
32
self
.
tp_slope
,
state_indices_tensor
,
q_start
=
0
,
q_end
=
attn_metadata
.
num_decode_tokens
,
slot_start
=
0
,
slot_end
=
attn_metadata
.
num_decodes
,
block_size
=
32
,
)
)
return
hidden
return
hidden
...
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
kv_cache
,
state_indices_tensor
,
attn_metadata
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
(
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
)
)
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
...
...
vllm/model_executor/models/bailing_moe_linear.py
0 → 100644
View file @
ab87f852
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
collections.abc
import
Iterable
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fla.ops.layernorm_guard
import
(
RMSNormGated
,
layernorm_fn
,
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.linear_attn
import
(
MiniMaxText01LinearAttention
,
MiniMaxText01LinearKernel
,
MiniMaxText01RMSNormTP
,
clear_linear_attention_cache_for_new_sequences
,
linear_attention_decode
,
linear_attention_prefill_and_mix
,
)
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
)
from
vllm.model_executor.layers.mla
import
MLAModules
,
MultiHeadLatentAttentionWrapper
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.model_executor.models.bailing_moe
import
BailingMLP
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.linear_attn
import
LinearAttentionMetadata
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
def
is_linear_layer
(
layer_idx
,
layer_group_size
):
if
layer_idx
is
None
:
return
False
if
layer_group_size
>
0
:
return
(
layer_idx
+
1
)
%
layer_group_size
!=
0
else
:
return
False
def
_build_rope_parameters
(
config
:
PretrainedConfig
)
->
dict
|
None
:
rope_parameters
=
copy
.
deepcopy
(
getattr
(
config
,
"rope_parameters"
,
None
))
or
{}
if
"rope_theta"
not
in
rope_parameters
and
hasattr
(
config
,
"rope_theta"
):
rope_parameters
[
"rope_theta"
]
=
config
.
rope_theta
if
"partial_rotary_factor"
not
in
rope_parameters
and
hasattr
(
config
,
"partial_rotary_factor"
):
rope_parameters
[
"partial_rotary_factor"
]
=
config
.
partial_rotary_factor
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
if
isinstance
(
rope_scaling
,
dict
):
rope_scaling
=
copy
.
deepcopy
(
rope_scaling
)
if
"type"
in
rope_scaling
and
"rope_type"
not
in
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
rope_scaling
.
pop
(
"type"
)
rope_parameters
.
update
(
rope_scaling
)
return
rope_parameters
or
None
class
BailingMoeV25MLAAttention
(
nn
.
Module
):
"""
MLA Attention for BailingMoeV2.5 full attention layers.
"""
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
layer_id
:
int
=
0
,
prefix
:
str
=
"attention"
,
cache_config
:
CacheConfig
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
layer_id
=
layer_id
self
.
prefix
=
prefix
# MLA dimensions
self
.
qk_nope_head_dim
=
getattr
(
config
,
"qk_nope_head_dim"
,
128
)
self
.
qk_rope_head_dim
=
getattr
(
config
,
"qk_rope_head_dim"
,
64
)
self
.
qk_head_dim
=
self
.
qk_nope_head_dim
+
self
.
qk_rope_head_dim
self
.
v_head_dim
=
getattr
(
config
,
"v_head_dim"
,
128
)
# LoRA ranks
self
.
q_lora_rank
=
getattr
(
config
,
"q_lora_rank"
,
None
)
self
.
kv_lora_rank
=
getattr
(
config
,
"kv_lora_rank"
,
512
)
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
num_heads
%
tp_size
==
0
self
.
num_local_heads
=
self
.
num_heads
//
tp_size
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
# KV projections
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
,
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
)
# Output projection
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
if
self
.
q_lora_rank
is
not
None
:
# Use fused_qkv_a_proj when q_lora_rank is set
self
.
fused_qkv_a_proj
=
MergedColumnParallelLinear
(
self
.
hidden_size
,
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fused_qkv_a_proj"
,
disable_tp
=
True
,
)
self
.
q_a_layernorm
=
RMSNorm
(
self
.
q_lora_rank
,
eps
=
config
.
rms_norm_eps
,
)
self
.
q_b_proj
=
ColumnParallelLinear
(
self
.
q_lora_rank
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
)
self
.
q_proj
=
None
self
.
kv_a_proj_with_mqa
=
None
else
:
# Direct projections when no q_lora_rank
self
.
q_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
num_heads
*
self
.
qk_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
)
self
.
kv_a_proj_with_mqa
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_a_proj_with_mqa"
,
)
self
.
fused_qkv_a_proj
=
None
self
.
q_a_layernorm
=
None
self
.
q_b_proj
=
None
rope_parameters
=
_build_rope_parameters
(
config
)
max_position
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
rotary_emb
=
get_rope
(
head_size
=
self
.
qk_rope_head_dim
,
max_position
=
max_position
,
is_neox_style
=
False
,
rope_parameters
=
rope_parameters
or
None
,
dtype
=
torch
.
float32
,
)
# Build MLAModules for MultiHeadLatentAttentionWrapper
mla_modules
=
MLAModules
(
kv_a_layernorm
=
self
.
kv_a_layernorm
,
kv_b_proj
=
self
.
kv_b_proj
,
rotary_emb
=
self
.
rotary_emb
,
o_proj
=
self
.
o_proj
,
fused_qkv_a_proj
=
self
.
fused_qkv_a_proj
,
kv_a_proj_with_mqa
=
self
.
kv_a_proj_with_mqa
,
q_a_layernorm
=
self
.
q_a_layernorm
,
q_b_proj
=
self
.
q_b_proj
,
q_proj
=
self
.
q_proj
,
indexer
=
None
,
is_sparse
=
False
,
topk_indices_buffer
=
None
,
)
self
.
mla_attn
=
MultiHeadLatentAttentionWrapper
(
self
.
hidden_size
,
self
.
num_local_heads
,
self
.
scaling
,
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
,
self
.
v_head_dim
,
self
.
q_lora_rank
,
self
.
kv_lora_rank
,
mla_modules
,
cache_config
,
quant_config
,
prefix
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Forward pass for MLA attention."""
return
self
.
mla_attn
(
positions
,
hidden_states
)
class
BailingMoEGate
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
params_dtype
:
torch
.
dtype
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
(
config
.
num_experts
,
config
.
hidden_size
),
dtype
=
self
.
params_dtype
,
),
)
if
getattr
(
config
,
"moe_router_enable_expert_bias"
,
False
):
self
.
expert_bias
=
nn
.
Parameter
(
torch
.
empty
((
config
.
num_experts
,),
dtype
=
torch
.
float32
),
)
else
:
self
.
expert_bias
=
None
def
forward
(
self
,
hidden_states
):
logits
=
F
.
linear
(
hidden_states
.
to
(
self
.
weight
.
dtype
),
self
.
weight
,
None
).
to
(
hidden_states
.
dtype
)
return
logits
class
BailingMoeV25
(
nn
.
Module
):
"""Bailing MoE v2.5 - standalone implementation for linear attention model."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
layer_id
:
int
=
0
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
layer_id
=
layer_id
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
num_experts
=
config
.
num_experts
self
.
top_k
=
config
.
num_experts_per_tok
norm_topk_prob
=
getattr
(
config
,
"norm_topk_prob"
,
None
)
# Ring-2.5 reference implementations normalize routing weights by default.
self
.
norm_expert_prob
=
True
if
norm_topk_prob
is
None
else
bool
(
norm_topk_prob
)
self
.
hidden_size
=
config
.
hidden_size
self
.
quant_config
=
quant_config
self
.
num_shared_experts
=
config
.
num_shared_experts
self
.
score_function
=
getattr
(
config
,
"score_function"
,
None
)
self
.
n_group
=
getattr
(
config
,
"n_group"
,
None
)
self
.
topk_group
=
getattr
(
config
,
"topk_group"
,
None
)
self
.
use_grouped_topk
=
self
.
n_group
is
not
None
and
self
.
topk_group
is
not
None
self
.
routed_scaling_factor
=
getattr
(
config
,
"routed_scaling_factor"
,
1.0
)
router_dtype
=
getattr
(
config
,
"router_dtype"
,
None
)
if
router_dtype
is
None
or
router_dtype
==
"fp32"
:
self
.
router_dtype
=
torch
.
float32
else
:
self
.
router_dtype
=
torch
.
bfloat16
# Gate for routing
self
.
gate
=
BailingMoEGate
(
config
=
config
,
params_dtype
=
self
.
router_dtype
,
prefix
=
f
"
{
prefix
}
.gate"
,
)
correction_bias
=
(
self
.
gate
.
expert_bias
if
self
.
gate
.
expert_bias
is
not
None
else
None
)
if
self
.
score_function
is
not
None
:
assert
(
self
.
score_function
==
"softmax"
and
correction_bias
is
None
)
or
(
self
.
score_function
==
"sigmoid"
and
correction_bias
is
not
None
),
(
"score_function and correction_bias should be "
"(softmax, None) or (sigmoid, not None)"
)
# Shared experts (using BailingMLP)
if
self
.
num_shared_experts
>
0
:
if
hasattr
(
config
,
"moe_shared_expert_intermediate_size"
):
intermediate_size
=
config
.
moe_shared_expert_intermediate_size
else
:
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
*=
config
.
num_shared_experts
self
.
shared_experts
=
BailingMLP
(
intermediate_size
=
intermediate_size
,
config
=
config
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
else
:
self
.
shared_experts
=
None
# Routed experts using SharedFusedMoE
self
.
experts
=
SharedFusedMoE
(
shared_experts
=
self
.
shared_experts
,
num_experts
=
self
.
num_experts
,
top_k
=
self
.
top_k
,
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
self
.
norm_expert_prob
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
self
.
score_function
,
e_score_correction_bias
=
correction_bias
,
num_expert_group
=
self
.
n_group
,
topk_group
=
self
.
topk_group
,
use_grouped_topk
=
self
.
use_grouped_topk
,
router_logits_dtype
=
self
.
router_dtype
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
# Ensure contiguous token-major layout before router/projections.
hidden_states
=
hidden_states
.
contiguous
().
view
(
-
1
,
hidden_size
)
# router_logits: (num_tokens, n_experts)
router_logits
=
self
.
gate
(
hidden_states
.
to
(
self
.
router_dtype
))
router_logits
=
router_logits
.
to
(
hidden_states
.
dtype
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
# Handle tuple return from SharedFusedMoE
if
self
.
shared_experts
is
not
None
:
shared_output
,
final_hidden_states
=
final_hidden_states
else
:
shared_output
=
None
final_hidden_states
*=
self
.
routed_scaling_factor
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
final_hidden_states
=
self
.
experts
.
maybe_all_reduce_tensor_model_parallel
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
BailingRMSNormTP
=
MiniMaxText01RMSNormTP
class
BailingGroupRMSNormGate
(
RMSNormGated
):
def
__init__
(
self
,
hidden_size
,
eps
=
1e-5
,
group_size
=
None
,
norm_before_gate
=
True
,
device
=
None
,
dtype
=
None
,
):
super
().
__init__
(
hidden_size
,
eps
=
eps
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
device
=
device
,
dtype
=
dtype
,
activation
=
"sigmoid"
,
)
# Add custom weight loader for TP sharding
self
.
weight
.
weight_loader
=
self
.
_weight_loader
@
staticmethod
def
_weight_loader
(
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Load weight with TP sharding."""
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
loaded_weight
.
shape
[
0
]
//
tp_size
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
param
.
data
.
copy_
(
loaded_weight
[
shard
].
contiguous
())
class
BailingMoELinearAttention
(
nn
.
Module
,
MambaBase
):
"""
Bailing MoE Linear Attention implementation using minimax backend.
This implements the linear attention mechanism from sglang, adapted for vLLM's
v1 engine with MambaBase interface support.
"""
@
property
def
mamba_type
(
self
)
->
str
:
return
"linear_attention"
def
get_state_shape
(
self
)
->
tuple
[
tuple
[
int
,
...],
...]:
"""Return state shape for linear attention cache.
Must match the calculation in get_mamba_state_shape_from_config.
"""
return
MambaStateShapeCalculator
.
linear_attention_state_shape
(
num_heads
=
self
.
total_num_heads
,
tp_size
=
self
.
tp_size
,
head_dim
=
self
.
head_dim
,
)
def
get_state_dtype
(
self
)
->
tuple
[
torch
.
dtype
,
...]:
"""Return state dtype for linear attention cache.
Must match the calculation in get_mamba_state_dtype_from_config.
"""
return
MambaStateDtypeCalculator
.
linear_attention_state_dtype
(
self
.
model_config
.
dtype
,
self
.
cache_config
.
mamba_cache_dtype
,
)
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
layer_id
:
int
=
0
,
prefix
:
str
=
"linear_attn"
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
):
super
().
__init__
()
self
.
layer_id
=
layer_id
self
.
hidden_size
=
config
.
hidden_size
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
total_kv_heads
=
config
.
num_attention_heads
# MHA
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
prefix
=
prefix
self
.
head_dim
=
(
config
.
head_dim
if
hasattr
(
config
,
"head_dim"
)
else
config
.
hidden_size
//
self
.
total_num_heads
)
self
.
hidden_inner_size
=
self
.
head_dim
*
self
.
total_num_heads
self
.
scaling
=
self
.
head_dim
**-
0.5
assert
self
.
total_num_heads
%
self
.
tp_size
==
0
self
.
tp_heads
=
self
.
total_num_heads
//
self
.
tp_size
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rope_theta
=
getattr
(
config
,
"rope_theta"
,
600000
)
self
.
tp_kv_heads
=
self
.
total_kv_heads
//
self
.
tp_size
self
.
q_size_per_rank
=
self
.
head_dim
*
self
.
tp_heads
self
.
kv_size_per_rank
=
self
.
head_dim
*
self
.
tp_kv_heads
self
.
use_qk_norm
=
getattr
(
config
,
"use_qk_norm"
,
False
)
self
.
linear_backend
=
"minimax"
self
.
linear_scale
=
self
.
linear_backend
==
"minimax"
self
.
linear_rope
=
getattr
(
config
,
"linear_rope"
,
True
)
if
hasattr
(
config
,
"use_linear_silu"
):
self
.
linear_silu
=
config
.
use_linear_silu
elif
hasattr
(
config
,
"linear_silu"
):
self
.
linear_silu
=
config
.
linear_silu
else
:
self
.
linear_silu
=
False
# Block size for lightning attention
self
.
BLOCK
=
getattr
(
config
,
"block"
,
256
)
self
.
query_key_value
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_heads
,
# MHA: kv_heads = num_heads
bias
=
(
config
.
use_bias
or
config
.
use_qkv_bias
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.query_key_value"
,
)
if
self
.
use_qk_norm
:
self
.
query_layernorm
=
RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
key_layernorm
=
RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
g_proj
=
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_inner_size
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.g_proj"
,
)
self
.
dense
=
RowParallelLinear
(
self
.
hidden_inner_size
,
self
.
hidden_size
,
bias
=
config
.
use_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.dense"
,
reduce_results
=
True
,
)
self
.
group_norm_size
=
getattr
(
config
,
"group_norm_size"
,
1
)
self
.
rms_norm_eps
=
float
(
getattr
(
config
,
"rms_norm_eps"
,
1e-5
))
assert
self
.
tp_size
<=
self
.
group_norm_size
,
(
"tp_size must be <= group_norm_size for local rms norm"
)
assert
self
.
group_norm_size
%
self
.
tp_size
==
0
,
(
"group_norm_size must be divisible by tp_size"
)
# When group_norm_size == 1, group_size equals hidden_size // tp_size
self
.
g_norm
=
BailingGroupRMSNormGate
(
hidden_size
=
self
.
hidden_inner_size
//
self
.
tp_size
,
eps
=
self
.
rms_norm_eps
,
group_size
=
(
self
.
hidden_inner_size
//
self
.
group_norm_size
if
self
.
group_norm_size
>
1
else
self
.
hidden_inner_size
//
self
.
tp_size
),
)
# use fp32 rotary embedding
rope_parameters
=
_build_rope_parameters
(
config
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
is_neox_style
=
True
,
dtype
=
torch
.
float32
,
rope_parameters
=
rope_parameters
or
None
,
)
# Build slope tensor for linear attention decay
num_hidden_layers
=
config
.
num_hidden_layers
slope_rate
=
MiniMaxText01LinearAttention
.
_build_slope_tensor
(
self
.
total_num_heads
)
if
num_hidden_layers
<=
1
:
self
.
slope_rate
=
slope_rate
*
(
1
+
1e-5
)
else
:
self
.
slope_rate
=
slope_rate
*
(
1
-
layer_id
/
(
num_hidden_layers
-
1
)
+
1e-5
)
self
.
tp_slope
=
self
.
slope_rate
[
self
.
tp_rank
*
self
.
tp_heads
:
(
self
.
tp_rank
+
1
)
*
self
.
tp_heads
].
contiguous
()
# Register for compilation
compilation_config
=
get_current_vllm_config
().
compilation_config
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
@
staticmethod
def
weight_direct_load
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Load weight for linear attention layers.
For FP8 quantized parameters, we need to use the weight_loader if available,
as it handles special cases like tensor parallelism sharding.
"""
# Check if param has a weight_loader (for vLLM ModelWeightParameter)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
None
)
if
weight_loader
is
not
None
:
# Use the weight_loader which handles TP sharding and quantization
weight_loader
(
param
,
loaded_weight
)
else
:
# Fall back to direct copy for standard tensors
assert
param
.
size
()
==
loaded_weight
.
size
(),
(
f
"Shape mismatch:
{
param
.
shape
}
vs
{
loaded_weight
.
shape
}
"
)
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
None
:
"""Forward method called by torch.ops.vllm.linear_attention"""
torch
.
ops
.
vllm
.
linear_attention
(
hidden_states
,
output
,
positions
,
self
.
prefix
,
)
def
_forward
(
self
,
hidden_states
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
None
:
"""Actual forward implementation."""
forward_context
=
get_forward_context
()
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
LinearAttentionMetadata
)
num_actual_tokens
=
(
attn_metadata
.
num_prefill_tokens
+
attn_metadata
.
num_decode_tokens
)
else
:
num_actual_tokens
=
hidden_states
.
shape
[
0
]
# QKV projection
qkv
,
_
=
self
.
query_key_value
(
hidden_states
[:
num_actual_tokens
])
# use rotary_emb support fp32
qkv
=
qkv
.
to
(
torch
.
float32
)
if
self
.
linear_silu
:
qkv
=
F
.
silu
(
qkv
)
# Split q, k, v
q
,
k
,
v
=
torch
.
split
(
qkv
,
[
self
.
q_size_per_rank
,
self
.
kv_size_per_rank
,
self
.
kv_size_per_rank
],
dim
=-
1
,
)
# Apply QK norm if needed
if
self
.
use_qk_norm
:
q
=
q
.
reshape
(
-
1
,
self
.
tp_heads
,
self
.
head_dim
)
k
=
k
.
reshape
(
-
1
,
self
.
tp_kv_heads
,
self
.
head_dim
)
q
=
layernorm_fn
(
q
,
self
.
query_layernorm
.
weight
.
data
,
bias
=
None
,
eps
=
self
.
rms_norm_eps
,
is_rms_norm
=
True
,
)
k
=
layernorm_fn
(
k
,
self
.
key_layernorm
.
weight
.
data
,
bias
=
None
,
eps
=
self
.
rms_norm_eps
,
is_rms_norm
=
True
,
)
q
=
q
.
reshape
(
-
1
,
self
.
q_size_per_rank
)
k
=
k
.
reshape
(
-
1
,
self
.
kv_size_per_rank
)
# Apply rotary embeddings
if
self
.
linear_rope
:
q
,
k
=
self
.
rotary_emb
(
positions
[:
num_actual_tokens
],
q
,
k
)
# Reshape to [batch, heads, seq_len, head_dim]
q
=
q
.
view
((
qkv
.
shape
[
0
],
self
.
tp_heads
,
self
.
head_dim
))
k
=
k
.
view
((
qkv
.
shape
[
0
],
self
.
tp_kv_heads
,
self
.
head_dim
))
v
=
v
.
view
((
qkv
.
shape
[
0
],
self
.
tp_kv_heads
,
self
.
head_dim
))
# Apply scaling if using minimax backend
if
self
.
linear_scale
:
q
=
q
*
self
.
scaling
# Get KV cache and state indices
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
kv_cache
,
state_indices_tensor
,
attn_metadata
)
# Compute attention
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
attn_metadata
is
None
:
hidden
=
torch
.
empty
(
(
q
.
shape
[
0
],
q
.
shape
[
1
]
*
q
.
shape
[
2
]),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
else
:
if
not
decode_only
:
hidden
=
self
.
_prefill_and_mix_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
else
:
hidden
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
# Apply group norm and gate (matching SGLang behavior)
gate
,
_
=
self
.
g_proj
(
hidden_states
[:
num_actual_tokens
])
if
self
.
group_norm_size
>
1
:
hidden
=
self
.
g_norm
(
hidden
,
gate
)
else
:
hidden
=
self
.
g_norm
(
hidden
)
hidden
=
F
.
sigmoid
(
gate
)
*
hidden
hidden
=
hidden
.
to
(
hidden_states
.
dtype
)
# Output projection
dense_out
,
_
=
self
.
dense
(
hidden
)
output
[:
num_actual_tokens
]
=
dense_out
def
_prefill_and_mix_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
"""Handle prefill (mixed with decode if any)."""
return
linear_attention_prefill_and_mix
(
q
=
q
,
k
=
k
,
v
=
v
,
kv_cache
=
kv_cache
,
state_indices_tensor
=
state_indices_tensor
,
attn_metadata
=
attn_metadata
,
slope_rate
=
self
.
tp_slope
,
block_size
=
self
.
BLOCK
,
decode_fn
=
self
.
_decode_infer
,
prefix_fn
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
,
layer_idx
=
self
.
layer_id
,
)
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
"""Handle decode (single token per sequence)."""
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefills
=
attn_metadata
.
num_prefills
hidden
=
linear_attention_decode
(
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
state_indices_tensor
,
q_start
=
num_prefill_tokens
,
q_end
=
None
,
slot_start
=
num_prefills
,
slot_end
=
None
,
block_size
=
32
,
)
return
hidden
class
BailingMoeV25DecoderLayer
(
nn
.
Module
):
"""Decoder layer supporting both linear and full attention."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
layer_id
:
int
=
0
,
prefix
:
str
=
"layer"
,
model_config
:
ModelConfig
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
layer_id
=
layer_id
self
.
hidden_size
=
config
.
hidden_size
# Determine attention type (0 = linear, 1 = full)
self
.
attention_type
=
getattr
(
config
,
"attention_type"
,
1
)
if
self
.
attention_type
==
0
:
# Linear attention
self
.
self_attn
=
BailingMoELinearAttention
(
config
,
quant_config
=
quant_config
,
layer_id
=
layer_id
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
model_config
=
model_config
,
cache_config
=
cache_config
,
)
else
:
# Full attention
self
.
self_attn
=
BailingMoeV25MLAAttention
(
config
,
quant_config
=
quant_config
,
layer_id
=
layer_id
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
cache_config
=
cache_config
,
)
# MLP/MoE
is_moe_layer
=
config
.
num_experts
>
1
and
layer_id
>=
getattr
(
config
,
"first_k_dense_replace"
,
0
)
if
is_moe_layer
:
self
.
mlp
=
BailingMoeV25
(
config
,
quant_config
=
quant_config
,
layer_id
=
layer_id
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
else
:
self
.
mlp
=
BailingMLP
(
intermediate_size
=
config
.
intermediate_size
,
config
=
config
,
quant_config
=
quant_config
,
reduce_results
=
True
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
# Layer norms
rms_norm_eps
=
float
(
getattr
(
config
,
"rms_norm_eps"
,
1e-5
))
self
.
input_layernorm
=
RMSNorm
(
self
.
hidden_size
,
eps
=
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
self
.
hidden_size
,
eps
=
rms_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
|
None
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
# Input layernorm
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
# Self attention
if
self
.
attention_type
==
0
:
# Linear attention uses output tensor
self_attention_output
=
torch
.
zeros_like
(
hidden_states
)
self
.
self_attn
(
hidden_states
=
hidden_states
,
output
=
self_attention_output
,
positions
=
positions
,
)
else
:
# Full attention
self_attention_output
=
self
.
self_attn
(
hidden_states
,
positions
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
self_attention_output
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
}
)
class
BailingMoeV25Model
(
nn
.
Module
):
"""Bailing MoE v2.5 Model with hybrid attention support."""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
model_config
=
vllm_config
.
model_config
quant_config
=
vllm_config
.
quant_config
cache_config
=
vllm_config
.
cache_config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_dim
=
config
.
hidden_size
# Determine layer types based on layer_group_size
self
.
layer_group_size
=
getattr
(
config
,
"layer_group_size"
,
1
)
self
.
num_layers
=
config
.
num_hidden_layers
# decoder_attention_types: 0 = linear, 1 = full
self
.
decoder_attention_types
=
[
0
if
is_linear_layer
(
i
,
self
.
layer_group_size
)
else
1
for
i
in
range
(
self
.
num_layers
)
]
# Embeddings
if
get_pp_group
().
is_first_rank
:
self
.
word_embeddings
=
VocabParallelEmbedding
(
self
.
vocab_size
,
self
.
embed_dim
,
org_num_embeddings
=
self
.
vocab_size
,
)
else
:
from
vllm.model_executor.models.utils
import
PPMissingLayer
self
.
word_embeddings
=
PPMissingLayer
()
# Layers
def
layer_fn
(
prefix
):
layer_idx
=
int
(
prefix
.
split
(
"."
)[
-
1
])
layer_config
=
copy
.
deepcopy
(
config
)
layer_config
.
attention_type
=
self
.
decoder_attention_types
[
layer_idx
]
return
BailingMoeV25DecoderLayer
(
config
=
layer_config
,
quant_config
=
quant_config
,
layer_id
=
layer_idx
,
prefix
=
prefix
,
model_config
=
model_config
,
cache_config
=
cache_config
,
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
num_layers
,
layer_fn
,
prefix
=
f
"
{
prefix
}
.layers"
)
# Final norm
norm_kwargs
=
{}
if
hasattr
(
config
,
"rms_norm_eps"
):
norm_kwargs
[
"eps"
]
=
config
.
rms_norm_eps
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
**
norm_kwargs
)
else
:
from
vllm.model_executor.models.utils
import
PPMissingLayer
self
.
norm
=
PPMissingLayer
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
word_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
None
:
hidden_states
=
self
.
word_embeddings
(
input_ids
)
else
:
hidden_states
=
inputs_embeds
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
hidden_states
=
hidden_states
,
positions
=
positions
,
attn_metadata
=
attn_metadata
,
residual
=
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
else
:
if
residual
is
not
None
:
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
else
:
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
"""Get expert parameter mapping for MoE layers."""
return
FusedMoE
.
make_expert_params_mapping
(
self
,
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
num_experts
,
num_redundant_experts
=
0
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
"""Load checkpoint weights with simplified mapping."""
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
# Stacked parameter mappings (fused projections)
stacked_mappings
=
[
(
".fused_qkv_a_proj"
,
".q_a_proj"
,
0
),
(
".fused_qkv_a_proj"
,
".kv_a_proj_with_mqa"
,
1
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
# Expert parameter mappings from FusedMoE
expert_mappings
=
list
(
self
.
get_expert_mapping
())
def
load_param
(
name
:
str
,
tensor
:
torch
.
Tensor
,
shard_id
=
None
)
->
bool
:
"""Load a single parameter."""
if
name
not
in
params_dict
or
is_pp_missing_parameter
(
name
,
self
):
return
False
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
return
False
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
shard_id
is
None
:
weight_loader
(
param
,
tensor
)
elif
isinstance
(
shard_id
,
int
):
weight_loader
(
param
,
tensor
,
shard_id
)
else
:
# Expert param: (expert_id, shard_id)
weight_loader
(
param
,
tensor
,
name
,
expert_id
=
shard_id
[
0
],
shard_id
=
shard_id
[
1
]
)
loaded_params
.
add
(
name
)
return
True
def
normalize_name
(
name
:
str
)
->
str
|
None
:
"""Normalize checkpoint name to model parameter name."""
# Skip special weights
if
name
.
startswith
(
"model.mtp"
):
return
None
# Remove 'model.' prefix if present
# (e.g., 'model.layers.0...' -> 'layers.0...')
name
=
name
.
removeprefix
(
"model."
)
# Map attention.dense based on layer type
if
"attention.dense"
in
name
:
layer_idx
=
(
int
(
name
.
split
(
"layers."
)[
1
].
split
(
"."
)[
0
])
if
"layers."
in
name
else
0
)
attn_name
=
(
"self_attn.dense"
if
is_linear_layer
(
layer_idx
,
self
.
config
.
layer_group_size
)
else
"self_attn.o_proj"
)
name
=
name
.
replace
(
"attention.dense"
,
attn_name
)
# Standard mappings
name
=
name
.
replace
(
"attention."
,
"self_attn."
)
name
=
name
.
replace
(
"mlp.gate.e_score_correction_bias"
,
"mlp.gate.expert_bias"
)
return
maybe_remap_kv_scale_name
(
name
,
params_dict
)
for
orig_name
,
weight
in
weights
:
norm_name
=
normalize_name
(
orig_name
)
if
norm_name
is
None
:
continue
# Try stacked mappings
loaded
=
False
for
param_suf
,
weight_suf
,
shard_id
in
stacked_mappings
:
if
weight_suf
not
in
norm_name
:
continue
mapped
=
norm_name
.
replace
(
weight_suf
,
param_suf
).
replace
(
"attention."
,
"self_attn."
)
if
load_param
(
mapped
,
weight
,
shard_id
):
loaded
=
True
break
if
loaded
:
continue
# Handle expert weights
if
"mlp.experts"
in
norm_name
:
# Expert bias
if
(
"mlp.experts.e_score_correction_bias"
in
norm_name
or
"mlp.experts.expert_bias"
in
norm_name
):
alt
=
norm_name
.
replace
(
"mlp.experts.e_score_correction_bias"
,
"mlp.gate.expert_bias"
).
replace
(
"mlp.experts.expert_bias"
,
"mlp.gate.expert_bias"
)
if
load_param
(
alt
,
weight
)
or
load_param
(
norm_name
,
weight
):
continue
# Routed experts
for
param_name
,
weight_name
,
expert_id
,
shard_id
in
expert_mappings
:
if
weight_name
not
in
norm_name
:
continue
mapped
=
norm_name
.
replace
(
weight_name
,
param_name
)
if
load_param
(
mapped
,
weight
,
(
expert_id
,
shard_id
)):
break
continue
# General parameters
load_param
(
norm_name
,
weight
)
return
loaded_params
class
BailingMoeV25ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsHybrid
,
SupportsPP
):
"""Bailing MoE v2.5 For CausalLM."""
packed_modules_mapping
=
{
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
BailingMoeV25Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
)
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
else
:
self
.
lm_head
=
PPMissingLayer
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
(
{
"hidden_states"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
(
(
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
}
)
@
classmethod
def
get_mamba_state_shape_from_config
(
cls
,
vllm_config
:
VllmConfig
,
)
->
tuple
[
tuple
[
int
,
...],
...]:
"""Calculate shape for linear attention cache."""
config
=
vllm_config
.
model_config
.
hf_config
tp_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
head_dim
=
getattr
(
config
,
"head_dim"
,
config
.
hidden_size
//
config
.
num_attention_heads
)
# Return base state shape from linear attention (no padding)
return
MambaStateShapeCalculator
.
linear_attention_state_shape
(
num_heads
=
config
.
num_attention_heads
,
tp_size
=
tp_size
,
head_dim
=
head_dim
,
)
@
classmethod
def
get_mamba_state_dtype_from_config
(
cls
,
vllm_config
:
VllmConfig
,
)
->
tuple
[
torch
.
dtype
,
...]:
return
MambaStateDtypeCalculator
.
linear_attention_state_dtype
(
vllm_config
.
model_config
.
dtype
,
vllm_config
.
cache_config
.
mamba_cache_dtype
,
)
@
classmethod
def
get_mamba_state_copy_func
(
cls
)
->
tuple
:
return
MambaStateCopyFuncCalculator
.
linear_attention_state_copy_func
()
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
return
self
.
model
.
get_expert_mapping
()
vllm/model_executor/models/registry.py
View file @
ab87f852
...
@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
"BailingMoeForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeForCausalLM"
),
"BailingMoeForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeForCausalLM"
),
"BailingMoeV2ForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeV2ForCausalLM"
),
"BailingMoeV2ForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeV2ForCausalLM"
),
"BailingMoeV2_5ForCausalLM"
:
(
"bailing_moe_linear"
,
"BailingMoeV25ForCausalLM"
),
"BambaForCausalLM"
:
(
"bamba"
,
"BambaForCausalLM"
),
"BambaForCausalLM"
:
(
"bamba"
,
"BambaForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
...
...
vllm/transformers_utils/model_arch_config_convertor.py
View file @
ab87f852
...
@@ -245,6 +245,7 @@ class ModelArchConfigConvertorBase:
...
@@ -245,6 +245,7 @@ class ModelArchConfigConvertorBase:
"longcat_flash"
,
"longcat_flash"
,
"pangu_ultra_moe"
,
"pangu_ultra_moe"
,
"pangu_ultra_moe_mtp"
,
"pangu_ultra_moe_mtp"
,
"bailing_hybrid"
,
):
):
return
self
.
hf_text_config
.
kv_lora_rank
is
not
None
return
self
.
hf_text_config
.
kv_lora_rank
is
not
None
elif
self
.
hf_text_config
.
model_type
==
"eagle"
:
elif
self
.
hf_text_config
.
model_type
==
"eagle"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment