Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c826c72a
Unverified
Commit
c826c72a
authored
Jan 18, 2026
by
Li Xie
Committed by
GitHub
Jan 18, 2026
Browse files
[Model] Support Step1 Model (#32511)
Signed-off-by:
xieli
<
xieli@stepfun.com
>
parent
fe36bf5e
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
472 additions
and
6 deletions
+472
-6
docs/models/supported_models.md
docs/models/supported_models.md
+2
-1
tests/models/registry.py
tests/models/registry.py
+3
-0
tests/models/test_initialization.py
tests/models/test_initialization.py
+4
-1
vllm/attention/layer.py
vllm/attention/layer.py
+11
-1
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/step1.py
vllm/model_executor/models/step1.py
+415
-0
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+4
-0
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+7
-1
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+25
-2
No files found.
docs/models/supported_models.md
View file @
c826c72a
...
...
@@ -452,9 +452,10 @@ th {
|
`Qwen3MoeForCausalLM`
| Qwen3MoE |
`Qwen/Qwen3-30B-A3B`
, etc. | ✅︎ | ✅︎ |
|
`Qwen3NextForCausalLM`
| Qwen3NextMoE |
`Qwen/Qwen3-Next-80B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`SeedOssForCausalLM`
| SeedOss |
`ByteDance-Seed/Seed-OSS-36B-Instruct`
, etc. | ✅︎ | ✅︎ |
|
`SolarForCausalLM`
| Solar Pro |
`upstage/solar-pro-preview-instruct`
, etc. | ✅︎ | ✅︎ |
|
`StableLmForCausalLM`
| StableLM |
`stabilityai/stablelm-3b-4e1t`
,
`stabilityai/stablelm-base-alpha-7b-v2`
, etc. | | |
|
`Starcoder2ForCausalLM`
| Starcoder2 |
`bigcode/starcoder2-3b`
,
`bigcode/starcoder2-7b`
,
`bigcode/starcoder2-15b`
, etc. | | ✅︎ |
|
`S
olar
ForCausalLM`
| S
olar Pro |
`upstage/solar-pro-preview-instruct
`
, etc. | ✅︎ | ✅︎ |
|
`S
tep1
ForCausalLM`
| S
tep-Audio |
`stepfun-ai/Step-Audio-EditX
`
, etc. | ✅︎ | ✅︎ |
|
`TeleChat2ForCausalLM`
| TeleChat2 |
`Tele-AI/TeleChat2-3B`
,
`Tele-AI/TeleChat2-7B`
,
`Tele-AI/TeleChat2-35B`
, etc. | ✅︎ | ✅︎ |
|
`TeleFLMForCausalLM`
| TeleFLM |
`CofeAI/FLM-2-52B-Instruct-2407`
,
`CofeAI/Tele-FLM`
, etc. | ✅︎ | ✅︎ |
|
`XverseForCausalLM`
| XVERSE |
`xverse/XVERSE-7B-Chat`
,
`xverse/XVERSE-13B-Chat`
,
`xverse/XVERSE-65B-Chat`
, etc. | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
c826c72a
...
...
@@ -472,6 +472,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ByteDance-Seed/Seed-OSS-36B-Instruct"
,
trust_remote_code
=
True
,
),
"Step1ForCausalLM"
:
_HfExamplesInfo
(
"stepfun-ai/Step-Audio-EditX"
,
trust_remote_code
=
True
),
"SmolLM3ForCausalLM"
:
_HfExamplesInfo
(
"HuggingFaceTB/SmolLM3-3B"
),
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-zephyr-3b"
),
"StableLmForCausalLM"
:
_HfExamplesInfo
(
"stabilityai/stablelm-3b-4e1t"
),
...
...
tests/models/test_initialization.py
View file @
c826c72a
...
...
@@ -115,8 +115,11 @@ def can_initialize(
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3.
# Step1ForCausalLM requires TRITON_ATTN for use_alibi_sqrt support.
attention_config
=
(
{
"backend"
:
"TRITON_ATTN"
}
if
model_arch
==
"GptOssForCausalLM"
else
None
{
"backend"
:
"TRITON_ATTN"
}
if
model_arch
in
(
"GptOssForCausalLM"
,
"Step1ForCausalLM"
)
else
None
)
if
model_arch
==
"WhisperForConditionalGeneration"
:
m
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
...
...
vllm/attention/layer.py
View file @
c826c72a
...
...
@@ -162,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
scale
:
float
,
num_kv_heads
:
int
|
None
=
None
,
alibi_slopes
:
list
[
float
]
|
None
=
None
,
use_alibi_sqrt
:
bool
|
None
=
None
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
logits_soft_cap
:
float
|
None
=
None
,
...
...
@@ -243,7 +244,16 @@ class Attention(nn.Module, AttentionLayerBase):
)
else
:
self
.
attn_backend
=
attn_backend
backend_supports_alibi_sqrt
=
self
.
attn_backend
.
supports_alibi_sqrt
()
use_alibi_sqrt
=
use_alibi_sqrt
if
use_alibi_sqrt
else
False
if
use_alibi_sqrt
and
not
backend_supports_alibi_sqrt
:
raise
ValueError
(
f
"use_alibi_sqrt is not supported by backend "
f
"
{
self
.
attn_backend
.
get_name
()
}
."
)
self
.
use_alibi_sqrt
=
bool
(
use_alibi_sqrt
)
if
backend_supports_alibi_sqrt
:
extra_impl_args
[
"use_alibi_sqrt"
]
=
self
.
use_alibi_sqrt
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if
(
...
...
vllm/model_executor/models/registry.py
View file @
c826c72a
...
...
@@ -185,6 +185,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3MoeForCausalLM"
:
(
"qwen3_moe"
,
"Qwen3MoeForCausalLM"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"SeedOssForCausalLM"
:
(
"seed_oss"
,
"SeedOssForCausalLM"
),
"Step1ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step3TextForCausalLM"
:
(
"step3_text"
,
"Step3TextForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
...
...
vllm/model_executor/models/step1.py
0 → 100644
View file @
c826c72a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared Step decoder blocks and the Step1 text model."""
from
__future__
import
annotations
import
math
from
collections.abc
import
Iterable
import
torch
from
torch
import
nn
from
vllm.attention.layer
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsPP
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
)
from
vllm.sequence
import
IntermediateTensors
STEP_PACKED_MODULES_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
def
_get_step_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
"""Reference ALiBi slopes used by Step models."""
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
base
=
torch
.
tensor
(
2
**
(
-
8.0
/
closest_power_of_2
),
dtype
=
torch
.
float32
,
)
slopes
=
torch
.
pow
(
base
,
torch
.
arange
(
1
,
1
+
closest_power_of_2
,
dtype
=
torch
.
int32
),
)
if
closest_power_of_2
!=
total_num_heads
:
extra_base
=
torch
.
tensor
(
2
**
(
-
4.0
/
closest_power_of_2
),
dtype
=
torch
.
float32
,
)
num_remaining_heads
=
total_num_heads
-
closest_power_of_2
extra_powers
=
torch
.
arange
(
1
,
1
+
2
*
num_remaining_heads
,
2
,
dtype
=
torch
.
int32
,
)
slopes
=
torch
.
cat
(
[
slopes
,
torch
.
pow
(
extra_base
,
extra_powers
)],
dim
=
0
,
)
return
slopes
class
StepAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
total_num_kv_heads
=
getattr
(
config
,
"num_attention_groups"
,
getattr
(
config
,
"num_key_value_heads"
,
1
)
)
if
total_num_kv_heads
is
None
or
total_num_kv_heads
<=
0
:
total_num_kv_heads
=
1
self
.
total_num_kv_heads
=
total_num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
self
.
hidden_size
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_num_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
getattr
(
config
,
"attention_bias"
,
False
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
self
.
hidden_size
,
bias
=
getattr
(
config
,
"attention_bias"
,
False
),
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
tp_rank
=
get_tensor_model_parallel_rank
()
head_start
=
tp_rank
*
self
.
num_heads
head_end
=
(
tp_rank
+
1
)
*
self
.
num_heads
alibi_slopes
=
_get_step_alibi_slopes
(
self
.
total_num_heads
)[
head_start
:
head_end
]
alibi_slopes
=
alibi_slopes
.
tolist
()
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scale
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
alibi_slopes
=
alibi_slopes
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_alibi_sqrt
=
True
,
attn_type
=
AttentionType
.
DECODER
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
StepMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
bias
:
bool
=
False
,
):
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
output_sizes
=
[
intermediate_size
,
intermediate_size
],
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
StepDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
StepAttention
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
self
.
mlp
=
StepMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
)
self
.
input_layernorm
=
RMSNorm
(
self
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
self
.
post_attention_layernorm
=
RMSNorm
(
self
.
hidden_size
,
eps
=
config
.
rms_norm_eps
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
# type: ignore[name-defined]
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
StepDecoderModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
# Need embed_tokens on first rank, and also on last rank if tie_word_embeddings
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
and
get_pp_group
().
is_last_rank
):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
StepDecoderLayer
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
prefix
=
maybe_prefix
(
prefix
,
"layers"
),
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
getattr
(
config
,
"aux_hidden_state_layers"
,
()
)
self
.
make_empty_intermediate_tensors
=
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
,
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
IntermediateTensors
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
assert
input_ids
is
not
None
hidden_states
=
self
.
embed_input_ids
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
aux_hidden_states
=
[]
for
idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]):
if
idx
in
self
.
aux_hidden_state_layers
:
if
residual
is
None
:
aux_hidden_states
.
append
(
hidden_states
)
else
:
aux_hidden_states
.
append
(
hidden_states
+
residual
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
aux_hidden_states
:
return
hidden_states
,
aux_hidden_states
return
hidden_states
class
Step1ForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
STEP_PACKED_MODULES_MAPPING
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
StepDecoderModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
)
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
getattr
(
config
,
"tie_word_embeddings"
,
True
):
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
None
# type: ignore[assignment]
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
|
None
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
IntermediateTensors
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
return
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
if
not
get_pp_group
().
is_last_rank
:
return
None
return
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/v1/attention/backend.py
View file @
c826c72a
...
...
@@ -172,6 +172,10 @@ class AttentionBackend(ABC):
def
supports_sink
(
cls
)
->
bool
:
return
False
@
classmethod
def
supports_alibi_sqrt
(
cls
)
->
bool
:
return
False
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
False
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
c826c72a
...
...
@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
AttentionType
.
ENCODER_DECODER
,
)
@
classmethod
def
supports_alibi_sqrt
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
return
True
...
...
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
int
|
None
=
None
,
sinks
:
torch
.
Tensor
|
None
=
None
,
use_alibi_sqrt
:
bool
=
False
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
...
...
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
f
"heads in the layer. Sinks shape:
{
sinks
.
shape
}
, "
f
"num_heads:
{
num_heads
}
."
)
self
.
use_alibi_sqrt
=
use_alibi_sqrt
self
.
supports_quant_query_input
=
current_platform
.
is_cuda
()
def
forward
(
...
...
@@ -513,6 +518,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
use_alibi_sqrt
=
self
.
use_alibi_sqrt
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
c826c72a
...
...
@@ -82,6 +82,7 @@ def kernel_unified_attention_2d(
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_ALIBI_SQRT
:
tl
.
constexpr
,
# bool
USE_QQ_BIAS
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
USE_SINKS
:
tl
.
constexpr
,
# bool
...
...
@@ -325,7 +326,16 @@ def kernel_unified_attention_2d(
)
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
if
USE_ALIBI_SQRT
:
relative_pos
=
seq_offset
-
(
context_len
+
query_pos
[:,
None
])
alibi_offset
=
tl
.
where
(
relative_pos
<=
0
,
-
tl
.
sqrt
((
-
relative_pos
).
to
(
tl
.
float32
)),
0.0
,
)
else
:
alibi_offset
=
seq_offset
-
context_len
S
+=
alibi_slope
[:,
None
]
*
alibi_offset
if
USE_QQ_BIAS
:
# compute key positions relative to query section
...
...
@@ -420,6 +430,7 @@ def kernel_unified_attention_3d(
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
USE_ALIBI_SQRT
:
tl
.
constexpr
,
# bool
USE_QQ_BIAS
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
USE_SINKS
:
tl
.
constexpr
,
# bool
...
...
@@ -669,7 +680,16 @@ def kernel_unified_attention_3d(
)
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
if
USE_ALIBI_SQRT
:
relative_pos
=
seq_offset
-
(
context_len
+
query_pos
[:,
None
])
alibi_offset
=
tl
.
where
(
relative_pos
<=
0
,
-
tl
.
sqrt
((
-
relative_pos
).
to
(
tl
.
float32
)),
0.0
,
)
else
:
alibi_offset
=
seq_offset
-
context_len
S
+=
alibi_slope
[:,
None
]
*
alibi_offset
if
USE_QQ_BIAS
:
# compute key positions relative to query section
...
...
@@ -888,6 +908,7 @@ def unified_attention(
sinks
=
None
,
# Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range
=
None
,
use_alibi_sqrt
=
False
,
):
assert
causal
,
"Only causal attention is supported"
assert
q_descale
is
None
,
"Q scales not supported"
...
...
@@ -994,6 +1015,7 @@ def unified_attention(
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_ALIBI_SQRT
=
use_alibi_sqrt
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
...
...
@@ -1045,6 +1067,7 @@ def unified_attention(
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
USE_ALIBI_SQRT
=
use_alibi_sqrt
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment