Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab87f852
Unverified
Commit
ab87f852
authored
Feb 26, 2026
by
Jiangyun Zhu
Committed by
GitHub
Feb 26, 2026
Browse files
[Model] Ring 2.5 (#35102)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
3827c8c5
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1407 additions
and
70 deletions
+1407
-70
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+3
-0
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+30
-5
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+1
-0
vllm/model_executor/layers/mamba/linear_attn.py
vllm/model_executor/layers/mamba/linear_attn.py
+124
-65
vllm/model_executor/models/bailing_moe_linear.py
vllm/model_executor/models/bailing_moe_linear.py
+1246
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+1
-0
No files found.
docs/models/supported_models.md
View file @
ab87f852
...
@@ -372,6 +372,7 @@ th {
...
@@ -372,6 +372,7 @@ th {
|
`BaiChuanForCausalLM`
| Baichuan2, Baichuan |
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc. | ✅︎ | ✅︎ |
|
`BaiChuanForCausalLM`
| Baichuan2, Baichuan |
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeForCausalLM`
| Ling |
`inclusionAI/Ling-lite-1.5`
,
`inclusionAI/Ling-plus`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2ForCausalLM`
| Ling |
`inclusionAI/Ling-mini-2.0`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2ForCausalLM`
| Ling |
`inclusionAI/Ling-mini-2.0`
, etc. | ✅︎ | ✅︎ |
|
`BailingMoeV2_5ForCausalLM`
| Ling |
`inclusionAI/Ling-2.5-1T`
,
`inclusionAI/Ring-2.5-1T`
| | ✅︎ |
|
`BambaForCausalLM`
| Bamba |
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
| ✅︎ | ✅︎ |
|
`BambaForCausalLM`
| Bamba |
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
| ✅︎ | ✅︎ |
|
`BloomForCausalLM`
| BLOOM, BLOOMZ, BLOOMChat |
`bigscience/bloom`
,
`bigscience/bloomz`
, etc. | | ✅︎ |
|
`BloomForCausalLM`
| BLOOM, BLOOMZ, BLOOMChat |
`bigscience/bloom`
,
`bigscience/bloomz`
, etc. | | ✅︎ |
|
`ChatGLMModel`
,
`ChatGLMForConditionalGeneration`
| ChatGLM |
`zai-org/chatglm2-6b`
,
`zai-org/chatglm3-6b`
,
`thu-coai/ShieldLM-6B-chatglm3`
, etc. | ✅︎ | ✅︎ |
|
`ChatGLMModel`
,
`ChatGLMForConditionalGeneration`
| ChatGLM |
`zai-org/chatglm2-6b`
,
`zai-org/chatglm3-6b`
,
`thu-coai/ShieldLM-6B-chatglm3`
, etc. | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
ab87f852
...
@@ -206,6 +206,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -206,6 +206,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"BailingMoeV2ForCausalLM"
:
_HfExamplesInfo
(
"BailingMoeV2ForCausalLM"
:
_HfExamplesInfo
(
"inclusionAI/Ling-mini-2.0"
,
trust_remote_code
=
True
"inclusionAI/Ling-mini-2.0"
,
trust_remote_code
=
True
),
),
"BailingMoeV2_5ForCausalLM"
:
_HfExamplesInfo
(
"inclusionAI/Ring-2.5-1T"
,
trust_remote_code
=
True
),
"BambaForCausalLM"
:
_HfExamplesInfo
(
"BambaForCausalLM"
:
_HfExamplesInfo
(
"ibm-ai-platform/Bamba-9B-v1"
,
"ibm-ai-platform/Bamba-9B-v1"
,
extras
=
{
"tiny"
:
"hmellor/tiny-random-BambaForCausalLM"
},
extras
=
{
"tiny"
:
"hmellor/tiny-random-BambaForCausalLM"
},
...
...
vllm/model_executor/layers/fla/ops/layernorm_guard.py
View file @
ab87f852
...
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
...
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
):
):
# Map the program id to the starting row of X and Y it should compute.
# Map the program id to the starting row of X and Y it should compute.
row_start
=
tl
.
program_id
(
0
)
*
ROWS_PER_BLOCK
row_start
=
tl
.
program_id
(
0
)
*
ROWS_PER_BLOCK
...
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
...
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
x
*=
z
*
tl
.
sigmoid
(
z
)
x
*=
z
*
tl
.
sigmoid
(
z
)
elif
ACTIVATION
==
"sigmoid"
:
x
*=
tl
.
sigmoid
(
z
)
# Compute mean and variance per row (reduce along axis 1)
# Compute mean and variance per row (reduce along axis 1)
if
not
IS_RMS_NORM
:
if
not
IS_RMS_NORM
:
...
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
...
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if
HAS_Z
and
NORM_BEFORE_GATE
:
if
HAS_Z
and
NORM_BEFORE_GATE
:
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
Z_base
=
Z
+
rows
[:,
None
]
*
stride_z_row
+
col_offsets
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
z
=
tl
.
load
(
Z_base
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
y
*=
z
*
tl
.
sigmoid
(
z
)
y
*=
z
*
tl
.
sigmoid
(
z
)
elif
ACTIVATION
==
"sigmoid"
:
y
*=
tl
.
sigmoid
(
z
)
# Write output
# Write output
tl
.
store
(
Y_base
,
y
,
mask
=
mask
)
tl
.
store
(
Y_base
,
y
,
mask
=
mask
)
...
@@ -178,6 +185,7 @@ def layer_norm_fwd(
...
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size
:
int
=
None
,
group_size
:
int
=
None
,
norm_before_gate
:
bool
=
True
,
norm_before_gate
:
bool
=
True
,
is_rms_norm
:
bool
=
False
,
is_rms_norm
:
bool
=
False
,
activation
:
str
=
"swish"
,
):
):
M
,
N
=
x
.
shape
M
,
N
=
x
.
shape
if
group_size
is
None
:
if
group_size
is
None
:
...
@@ -232,9 +240,12 @@ def layer_norm_fwd(
...
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps
,
eps
,
BLOCK_N
=
BLOCK_N
,
BLOCK_N
=
BLOCK_N
,
ROWS_PER_BLOCK
=
rows_per_block
,
ROWS_PER_BLOCK
=
rows_per_block
,
HAS_BIAS
=
bias
is
not
None
,
HAS_Z
=
z
is
not
None
,
NORM_BEFORE_GATE
=
norm_before_gate
,
NORM_BEFORE_GATE
=
norm_before_gate
,
IS_RMS_NORM
=
is_rms_norm
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
ACTIVATION
=
activation
,
)
)
return
out
,
mean
,
rstd
return
out
,
mean
,
rstd
...
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
...
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size
=
group_size
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
is_rms_norm
=
is_rms_norm
,
activation
=
activation
,
)
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
x_shape_og
=
x_shape_og
...
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
...
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx
.
group_size
=
group_size
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
activation
=
activation
return
y
.
reshape
(
x_shape_og
)
return
y
.
reshape
(
x_shape_og
)
...
@@ -296,17 +310,25 @@ def layernorm_fn(
...
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
):
return
LayerNormFn
.
apply
(
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
,
activation
)
)
def
rmsnorm_fn
(
def
rmsnorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
activation
:
str
=
"swish"
,
):
):
return
LayerNormFn
.
apply
(
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
,
activation
)
)
...
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
...
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate
:
bool
=
False
,
norm_before_gate
:
bool
=
False
,
device
:
torch
.
device
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
dtype
:
torch
.
dtype
|
None
=
None
,
dtype
:
torch
.
dtype
|
None
=
None
,
activation
:
str
=
"swish"
,
):
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
...
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
...
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
super
().
__init__
()
self
.
eps
=
eps
self
.
eps
=
eps
self
.
activation
=
activation
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
self
.
group_size
=
group_size
self
.
group_size
=
group_size
...
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
...
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps
=
self
.
eps
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
,
norm_before_gate
=
self
.
norm_before_gate
,
activation
=
self
.
activation
,
)
)
vllm/model_executor/layers/layernorm.py
View file @
ab87f852
...
@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
...
@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
eps
=
self
.
eps
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
,
norm_before_gate
=
self
.
norm_before_gate
,
activation
=
self
.
activation
,
)
)
...
...
vllm/model_executor/layers/mamba/linear_attn.py
View file @
ab87f852
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Callable
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
self
.
weight
.
weight_loader
=
self
.
weight_loader
self
.
weight
.
weight_loader
=
self
.
weight_loader
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
return
@
staticmethod
@
staticmethod
def
weight_loader
(
def
weight_loader
(
...
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
shard_size
=
loaded_weight
.
shape
[
0
]
//
tp_world
shard_size
=
loaded_weight
.
shape
[
0
]
//
tp_world
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
param
.
data
.
copy_
(
loaded_weight
[
shard
])
param
.
data
.
copy_
(
loaded_weight
[
shard
])
return
def
_forward
(
def
_forward
(
self
,
self
,
...
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
...
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
return
q
,
k
return
q
,
k
def clear_linear_attention_cache_for_new_sequences(
    kv_cache: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    attn_metadata: LinearAttentionMetadata,
) -> None:
    """Zero the cache slot of every prefill request that has no prior context.

    For each prefill request, the context length is computed as its entry in
    ``attn_metadata.seq_lens`` minus its query length (the difference of two
    consecutive ``attn_metadata.query_start_loc`` entries). When that context
    length is zero, the ``kv_cache`` slot selected through
    ``state_indices_tensor`` is overwritten with zeros in place.

    Args:
        kv_cache: Cache tensor indexed by slot along its first dimension.
        state_indices_tensor: Maps a request position to its cache slot.
        attn_metadata: Provides ``num_prefills``, ``num_decode_tokens``,
            ``query_start_loc`` and ``seq_lens``. The two counters default to
            0 when the attribute is absent.
    """
    prefill_count = getattr(attn_metadata, "num_prefills", 0)
    if prefill_count <= 0:
        # Nothing is being prefilled, so no slot can need clearing.
        return

    # NOTE(review): the decode *token* count is used as the request offset;
    # presumably decode requests come first with one token each -- confirm.
    decode_offset = getattr(attn_metadata, "num_decode_tokens", 0)
    start_locs = attn_metadata.query_start_loc

    for request_pos in range(decode_offset, decode_offset + prefill_count):
        tokens_in_query = start_locs[request_pos + 1] - start_locs[request_pos]
        tokens_in_context = attn_metadata.seq_lens[request_pos] - tokens_in_query
        if tokens_in_context == 0:
            slot = state_indices_tensor[request_pos]
            kv_cache[slot, ...] = 0
def linear_attention_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    slope_rate: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    q_start: int = 0,
    q_end: int | None = None,
    slot_start: int = 0,
    slot_end: int | None = None,
    block_size: int = 32,
) -> torch.Tensor:
    """Run the Triton linear-attention decode kernel on a token/slot window.

    Rows ``q_start:q_end`` are taken from ``q``, ``k`` and ``v``. Each slice
    gets a new axis inserted at position 2 and is made contiguous. The cache
    slots are ``state_indices_tensor[slot_start:slot_end]``.

    Args:
        q, k, v: Projection tensors sliced along their first dimension.
        kv_cache: Cache tensor handed to the kernel unchanged.
        slope_rate: Forwarded to the kernel unchanged.
        state_indices_tensor: Source of the cache-slot ids.
        q_start, q_end: Row window applied to ``q``/``k``/``v``.
        slot_start, slot_end: Window applied to ``state_indices_tensor``.
        block_size: Forwarded to the kernel unchanged.

    Returns:
        Whatever ``linear_decode_forward_triton`` returns for these inputs.
    """
    token_window = slice(q_start, q_end)
    # Same preparation for all three projections, evaluated in q, k, v order.
    q_dec, k_dec, v_dec = (
        tensor[token_window].unsqueeze(2).contiguous() for tensor in (q, k, v)
    )
    slots = state_indices_tensor[slot_start:slot_end]
    return linear_decode_forward_triton(
        q_dec, k_dec, v_dec, kv_cache, slope_rate, slots, block_size
    )
def linear_attention_prefill_and_mix(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    attn_metadata: LinearAttentionMetadata,
    slope_rate: torch.Tensor,
    block_size: int,
    decode_fn: Callable[..., torch.Tensor],
    prefix_fn: Callable[..., torch.Tensor],
    layer_idx: int | None = None,
) -> torch.Tensor:
    """Compute linear attention for a batch mixing prefill and decode requests.

    Each prefill request is processed on its own through ``prefix_fn``. If the
    batch has decode tokens, they are processed in one ``decode_fn`` call. The
    decode output is placed first, followed by the prefill outputs in request
    order, and everything is concatenated along dimension 0.

    Args:
        q, k, v: Projection tensors; prefill requests are cut out of them
            along the first dimension using ``attn_metadata.query_start_loc``.
        kv_cache: Cache tensor indexed by slot along its first dimension.
        state_indices_tensor: Maps a request position to its cache slot.
        attn_metadata: Provides ``num_prefills`` (0 if absent),
            ``num_decode_tokens`` and ``query_start_loc``.
        slope_rate: Forwarded to ``prefix_fn``.
        block_size: Forwarded to ``prefix_fn``.
        decode_fn: Called as ``decode_fn(q, k, v, kv_cache,
            state_indices_tensor, attn_metadata)``.
        prefix_fn: Called as ``prefix_fn(qs, ks, vs, cache_slice, slope_rate,
            block_size, layer_idx=layer_idx)``.
        layer_idx: Forwarded to ``prefix_fn``.

    Returns:
        The concatenated outputs, or an empty ``(0, q.size(-1))`` tensor on
        ``q``'s device and dtype when nothing was produced.
    """
    outputs: list[torch.Tensor] = []

    for prefill_pos in range(getattr(attn_metadata, "num_prefills", 0)):
        # Stop once the prefill index runs past either lookup table.
        if prefill_pos >= len(attn_metadata.query_start_loc):
            break
        if prefill_pos >= len(state_indices_tensor):
            break

        # NOTE(review): the decode *token* count is used as the request
        # offset; presumably decode requests come first -- confirm.
        request_pos = attn_metadata.num_decode_tokens + prefill_pos
        token_begin = attn_metadata.query_start_loc[request_pos]
        token_end = attn_metadata.query_start_loc[request_pos + 1]
        slot = state_indices_tensor[request_pos]

        # Swap the first two axes of each per-request slice for prefix_fn.
        qs, ks, vs = (
            tensor[token_begin:token_end].transpose(0, 1).contiguous()
            for tensor in (q, k, v)
        )
        request_cache = kv_cache[slot, ...]
        request_out = prefix_fn(
            qs,
            ks,
            vs,
            request_cache,
            slope_rate,
            block_size,
            layer_idx=layer_idx,
        )
        outputs.append(request_out.contiguous())

    if attn_metadata.num_decode_tokens > 0:
        decode_out = decode_fn(
            q, k, v, kv_cache, state_indices_tensor, attn_metadata
        )
        # Decode results go ahead of every prefill result.
        outputs.insert(0, decode_out)

    if outputs:
        return torch.concat(outputs, dim=0).contiguous()
    return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
class
MiniMaxText01LinearKernel
:
class
MiniMaxText01LinearKernel
:
@
staticmethod
@
staticmethod
def
jit_linear_forward_prefix
(
def
jit_linear_forward_prefix
(
...
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def
_prefill_and_mix_infer
(
def
_prefill_and_mix_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
):
hidden
=
[]
return
linear_attention_prefill_and_mix
(
for
_prefill_idx
in
range
(
getattr
(
attn_metadata
,
"num_prefills"
,
0
)):
q
=
q
,
if
_prefill_idx
>=
len
(
attn_metadata
.
query_start_loc
):
k
=
k
,
break
v
=
v
,
if
_prefill_idx
>=
len
(
state_indices_tensor
):
kv_cache
=
kv_cache
,
break
state_indices_tensor
=
state_indices_tensor
,
offset
=
attn_metadata
.
num_decode_tokens
attn_metadata
=
attn_metadata
,
_start
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
]
slope_rate
=
self
.
tp_slope
,
_end
=
attn_metadata
.
query_start_loc
[
offset
+
_prefill_idx
+
1
]
block_size
=
self
.
BLOCK
,
slot_id
=
state_indices_tensor
[
offset
+
_prefill_idx
]
decode_fn
=
self
.
_decode_infer
,
qs
=
q
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
prefix_fn
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
,
ks
=
k
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
vs
=
v
[
_start
:
_end
].
transpose
(
0
,
1
).
contiguous
()
slice_layer_cache
=
kv_cache
[
slot_id
,
...]
out_slice
=
MiniMaxText01LinearKernel
.
jit_linear_forward_prefix
(
qs
,
ks
,
vs
,
slice_layer_cache
,
self
.
tp_slope
,
self
.
BLOCK
,
layer_idx
=
self
.
layer_idx
,
layer_idx
=
self
.
layer_idx
,
)
)
hidden
.
append
(
out_slice
.
contiguous
())
if
attn_metadata
.
num_decode_tokens
>
0
:
hidden_decode
=
self
.
_decode_infer
(
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
)
hidden
.
insert
(
0
,
hidden_decode
)
if
not
hidden
:
return
torch
.
empty
((
0
,
q
.
size
(
-
1
)),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
hidden
=
torch
.
concat
(
hidden
,
dim
=
0
).
contiguous
()
return
hidden
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
def
_decode_infer
(
self
,
q
,
k
,
v
,
kv_cache
,
state_indices_tensor
,
attn_metadata
):
q
=
q
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
hidden
=
linear_attention_decode
(
k
=
k
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
q
,
v
=
v
[:
attn_metadata
.
num_decode_tokens
].
unsqueeze
(
2
).
contiguous
()
k
,
slot_id
=
state_indices_tensor
[:
attn_metadata
.
num_decodes
]
v
,
hidden
=
linear_decode_forward_triton
(
kv_cache
,
q
,
k
,
v
,
kv_cache
,
self
.
tp_slope
,
slot_id
,
32
self
.
tp_slope
,
state_indices_tensor
,
q_start
=
0
,
q_end
=
attn_metadata
.
num_decode_tokens
,
slot_start
=
0
,
slot_end
=
attn_metadata
.
num_decodes
,
block_size
=
32
,
)
)
return
hidden
return
hidden
...
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
...
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
if
attn_metadata
is
not
None
:
if
attn_metadata
is
not
None
:
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
][
0
]
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
state_indices_tensor
=
attn_metadata
.
state_indices_tensor
clear_linear_attention_cache_for_new_sequences
(
num_prefills
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
kv_cache
,
state_indices_tensor
,
attn_metadata
if
num_prefills
>
0
:
num_decode_tokens
=
getattr
(
attn_metadata
,
"num_decode_tokens"
,
0
)
for
prefill_idx
in
range
(
num_prefills
):
q_start
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
]
q_end
=
attn_metadata
.
query_start_loc
[
num_decode_tokens
+
prefill_idx
+
1
]
query_len
=
q_end
-
q_start
context_len
=
(
attn_metadata
.
seq_lens
[
num_decode_tokens
+
prefill_idx
]
-
query_len
)
)
if
context_len
==
0
:
block_to_clear
=
state_indices_tensor
[
num_decode_tokens
+
prefill_idx
]
kv_cache
[
block_to_clear
,
...]
=
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
decode_only
=
getattr
(
attn_metadata
,
"num_prefills"
,
0
)
==
0
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
...
...
vllm/model_executor/models/bailing_moe_linear.py
0 → 100644
View file @
ab87f852
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
ab87f852
...
@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
"BailingMoeForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeForCausalLM"
),
"BailingMoeForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeForCausalLM"
),
"BailingMoeV2ForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeV2ForCausalLM"
),
"BailingMoeV2ForCausalLM"
:
(
"bailing_moe"
,
"BailingMoeV2ForCausalLM"
),
"BailingMoeV2_5ForCausalLM"
:
(
"bailing_moe_linear"
,
"BailingMoeV25ForCausalLM"
),
"BambaForCausalLM"
:
(
"bamba"
,
"BambaForCausalLM"
),
"BambaForCausalLM"
:
(
"bamba"
,
"BambaForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
...
...
vllm/transformers_utils/model_arch_config_convertor.py
View file @
ab87f852
...
@@ -245,6 +245,7 @@ class ModelArchConfigConvertorBase:
...
@@ -245,6 +245,7 @@ class ModelArchConfigConvertorBase:
"longcat_flash"
,
"longcat_flash"
,
"pangu_ultra_moe"
,
"pangu_ultra_moe"
,
"pangu_ultra_moe_mtp"
,
"pangu_ultra_moe_mtp"
,
"bailing_hybrid"
,
):
):
return
self
.
hf_text_config
.
kv_lora_rank
is
not
None
return
self
.
hf_text_config
.
kv_lora_rank
is
not
None
elif
self
.
hf_text_config
.
model_type
==
"eagle"
:
elif
self
.
hf_text_config
.
model_type
==
"eagle"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment