Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8c6fcf6
Commit
a8c6fcf6
authored
Aug 01, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev' of
http://10.16.6.30/dcutoolkit/deeplearing/vllm
into v0.9.2-dev
parents
0480314d
66540380
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
2124 additions
and
1 deletion
+2124
-1
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+6
-0
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+2
-0
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
+296
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/model_executor/models/step3_text.py
vllm/model_executor/models/step3_text.py
+521
-0
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+1052
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+2
-0
vllm/reasoning/step3_reasoning_parser.py
vllm/reasoning/step3_reasoning_parser.py
+108
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+3
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+6
-0
vllm/transformers_utils/configs/step3_vl.py
vllm/transformers_utils/configs/step3_vl.py
+123
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+2
-1
No files found.
docs/models/supported_models.md
View file @
a8c6fcf6
...
@@ -598,6 +598,7 @@ Specified using `--task generate`.
...
@@ -598,6 +598,7 @@ Specified using `--task generate`.
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-7B`
| | ✅︎ | ✅︎
\*
|
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-7B`
| | ✅︎ | ✅︎
\*
|
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
|
`Step3VLForConditionalGeneration`
| Step3-VL | T + I
<sup>
+
</sup>
|
`stepfun-ai/step3`
| | ✅︎ | ✅︎ |
|
`TarsierForConditionalGeneration`
| Tarsier | T + I
<sup>
E+
</sup>
|
`omni-search/Tarsier-7b`
,
`omni-search/Tarsier-34b`
| | ✅︎ | ✅︎ |
|
`TarsierForConditionalGeneration`
| Tarsier | T + I
<sup>
E+
</sup>
|
`omni-search/Tarsier-7b`
,
`omni-search/Tarsier-34b`
| | ✅︎ | ✅︎ |
|
`Tarsier2ForConditionalGeneration`
<sup>
^
</sup>
| Tarsier2 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`omni-research/Tarsier2-Recap-7b`
,
`omni-research/Tarsier2-7b-0115`
| | ✅︎ | ✅︎ |
|
`Tarsier2ForConditionalGeneration`
<sup>
^
</sup>
| Tarsier2 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`omni-research/Tarsier2-Recap-7b`
,
`omni-research/Tarsier2-7b-0115`
| | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
a8c6fcf6
...
@@ -266,6 +266,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -266,6 +266,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-zephyr-3b"
)),
# noqa: E501
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-zephyr-3b"
)),
# noqa: E501
"StableLmForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-3b-4e1t"
)),
"StableLmForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-3b-4e1t"
)),
"Starcoder2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"bigcode/starcoder2-3b"
)),
"Starcoder2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"bigcode/starcoder2-3b"
)),
"Step3TextForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stepfun-ai/step3"
),
trust_remote_code
=
True
,
is_available_online
=
False
),
"SolarForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"upstage/solar-pro-preview-instruct"
)),
"SolarForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"upstage/solar-pro-preview-instruct"
)),
"TeleChat2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Tele-AI/TeleChat2-3B"
),
"TeleChat2ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Tele-AI/TeleChat2-3B"
),
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
@@ -423,6 +426,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -423,6 +426,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2_5OmniForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-7B-AWQ"
)),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-7B-AWQ"
)),
# noqa: E501
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
)),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
)),
"SmolVLMForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
)),
# noqa: E501
"SmolVLMForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
)),
# noqa: E501
"Step3VLForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stepfun-ai/step3"
),
trust_remote_code
=
True
,
is_available_online
=
False
),
"UltravoxModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
),
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
),
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"TarsierForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier-7b"
,
# noqa: E501
"TarsierForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier-7b"
,
# noqa: E501
...
...
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
a8c6fcf6
...
@@ -15,6 +15,7 @@ from .minimax_tool_parser import MinimaxToolParser
...
@@ -15,6 +15,7 @@ from .minimax_tool_parser import MinimaxToolParser
from
.mistral_tool_parser
import
MistralToolParser
from
.mistral_tool_parser
import
MistralToolParser
from
.phi4mini_tool_parser
import
Phi4MiniJsonToolParser
from
.phi4mini_tool_parser
import
Phi4MiniJsonToolParser
from
.pythonic_tool_parser
import
PythonicToolParser
from
.pythonic_tool_parser
import
PythonicToolParser
from
.step3_tool_parser
import
Step3ToolParser
from
.xlam_tool_parser
import
xLAMToolParser
from
.xlam_tool_parser
import
xLAMToolParser
__all__
=
[
__all__
=
[
...
@@ -31,6 +32,7 @@ __all__ = [
...
@@ -31,6 +32,7 @@ __all__ = [
"PythonicToolParser"
,
"PythonicToolParser"
,
"Phi4MiniJsonToolParser"
,
"Phi4MiniJsonToolParser"
,
"DeepSeekV3ToolParser"
,
"DeepSeekV3ToolParser"
,
"Step3ToolParser"
,
"xLAMToolParser"
,
"xLAMToolParser"
,
"MinimaxToolParser"
,
"MinimaxToolParser"
,
"Glm4MoeModelToolParser"
,
"Glm4MoeModelToolParser"
,
...
...
vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py
0 → 100644
View file @
a8c6fcf6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
json
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
# Module-level logger, created through vLLM's init_logger and named after
# this module.
logger = init_logger(__name__)
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """
    Tool parser for a model that uses a specific XML-like format for tool calls.
    This version uses a robust, stateful, cursor-based streaming parser and
    consolidates tool arguments into a single message.

    The tool block is delimited by ``<|tool_calls_begin|>`` and
    ``<|tool_calls_end|>``. Each call inside it is delimited by
    ``<|tool_call_begin|>`` and ``<|tool_call_end|>`` and carries a
    ``<steptml:invoke name="...">`` element with zero or more
    ``<steptml:parameter name="...">value</steptml:parameter>`` children.
    In the non-streaming path the call body must additionally have the shape
    ``function<|tool_sep|><steptml:invoke ...>``.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [
        TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
    ]

    # Compiled once at class creation; the streaming path re-parses the
    # partial tool body on every delta.
    _INVOKE_RE = re.compile(r'<steptml:invoke name="([^"]+)">')
    _PARAM_RE = re.compile(
        r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>')

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Cursor into ``current_text``: everything before it has already been
        # consumed by the streaming parser.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Disable ``skip_special_tokens`` when tools may be called.

        Applied only when the request declares tools and ``tool_choice`` is
        not ``'none'`` (presumably so the ``<|tool_...|>`` markers survive
        detokenization and reach this parser).
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str
    ) -> tuple[Optional[str], Optional[dict[str, str]]]:
        """Extract the function name and raw string parameters from a call.

        Args:
            action_text: Text containing a ``<steptml:invoke name="...">``
                element; it may be incomplete while streaming.

        Returns:
            ``(name, params)`` where ``params`` maps each fully closed
            parameter name to its whitespace-stripped text, or
            ``(None, None)`` if no invoke tag is present yet.
        """
        invoke_match = Step3ToolParser._INVOKE_RE.search(action_text)
        if not invoke_match:
            return None, None
        func_name = invoke_match.group(1)
        # Later duplicates of a parameter name overwrite earlier ones.
        params: dict[str, str] = {
            name: value.strip()
            for name, value in Step3ToolParser._PARAM_RE.findall(action_text)
        }
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Coerce string parameters to the types declared in the tool schema.

        The first tool in ``request.tools`` whose function name equals
        ``func_name`` is used. ``params`` is modified in place and also
        returned. Values that cannot be converted (or whose declared type is
        not one of string/integer/number/boolean/null) are left unchanged.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            # ``properties`` may be missing or explicitly null.
            properties = schema.get("properties") or {}
            # Only existing keys are reassigned, so iterating ``items()``
            # while writing back is safe.
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                prop = properties.get(key, {})
                if not isinstance(prop, dict):
                    # e.g. a boolean JSON schema; nothing to cast against.
                    continue
                typ = prop.get("type")
                if typ == "string":
                    params[key] = value.strip()
                elif typ == "integer":
                    with contextlib.suppress(ValueError):
                        params[key] = int(value)
                elif typ == "number":
                    with contextlib.suppress(ValueError):
                        params[key] = float(value)
                elif typ == "boolean":
                    lower_val = value.lower()
                    if lower_val in ("true", "false"):
                        params[key] = lower_val == "true"
                elif typ == "null":
                    if value.lower() == "null":
                        params[key] = None
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse ``current_text`` from the saved cursor.

        Only ``current_text`` and ``request`` are read; the other arguments
        are accepted for interface compatibility. Each invocation returns at
        most one delta: plain content, a tool-call header (index, id, name),
        or the complete JSON arguments of a finished tool call. ``None``
        means more text is needed before anything can be emitted.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position:]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(
                            unprocessed_text.strip()) and unprocessed_text:
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the text preceding the marker; the marker itself
                    # is consumed on the next invocation.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Skip whitespace between markers and keep the cursor in sync.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = (
                self.current_tool_id != -1 and
                self.prev_tool_call_arr[self.current_tool_id].get("finished"))
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = False
                    continue

                # Possibly a partially received begin marker: wait.
                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                    self.current_tool_id].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                # Body so far is only a prefix of the end marker: wait.
                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
                        tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(
                    tool_body)
                if not function_name:
                    return None

                tool_call_arr = {
                    "name": function_name,
                    "parameters": arguments or {}
                }

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(
                        tool_call_arr)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name))
                    ])

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(
                    tool_call_arr)

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request)
                    # Always emit the arguments, even when empty ("{}"), so
                    # the streamed arguments of a parameterless call are
                    # valid JSON and match what extract_tool_calls returns.
                    final_args_json = json.dumps(final_args,
                                                 ensure_ascii=False)
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=final_args_json))
                    ])

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls and plain content.

        If the tool block is absent, unterminated, or yields no valid call,
        the whole ``model_output`` is returned as content with
        ``tools_called=False``. Otherwise the text before and after the tool
        block is concatenated, stripped and returned as content (``None``
        when empty) next to the parsed calls.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        call_parts = tool_block.split(self.TOOL_CALL_BEGIN)

        for part in call_parts:
            # Skip empty fragments and calls without an end marker.
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            # Only calls of type "function" are recognised.
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(
                invoke_part)

            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict,
                                                   request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=function_name,
                                                   arguments=params_str)))
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
a8c6fcf6
...
@@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM"
:
(
"qwen3"
,
"Qwen3ForCausalLM"
),
"Qwen3ForCausalLM"
:
(
"qwen3"
,
"Qwen3ForCausalLM"
),
"Qwen3MoeForCausalLM"
:
(
"qwen3_moe"
,
"Qwen3MoeForCausalLM"
),
"Qwen3MoeForCausalLM"
:
(
"qwen3_moe"
,
"Qwen3MoeForCausalLM"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"Step3TextForCausalLM"
:
(
"step3_text"
,
"Step3TextForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"Starcoder2ForCausalLM"
:
(
"starcoder2"
,
"Starcoder2ForCausalLM"
),
"Starcoder2ForCausalLM"
:
(
"starcoder2"
,
"Starcoder2ForCausalLM"
),
...
@@ -228,6 +229,7 @@ _MULTIMODAL_MODELS = {
...
@@ -228,6 +229,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"Step3VLForConditionalGeneration"
:
(
"step3_vl"
,
"Step3VLForConditionalGeneration"
),
# noqa: E501
"Phi4MMForCausalLM"
:
(
"phi4mm"
,
"Phi4MMForCausalLM"
),
"Phi4MMForCausalLM"
:
(
"phi4mm"
,
"Phi4MMForCausalLM"
),
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
(
"qwen2_vl"
,
"Tarsier2ForConditionalGeneration"
),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
(
"qwen2_vl"
,
"Tarsier2ForConditionalGeneration"
),
# noqa: E501
...
...
vllm/model_executor/models/step3_text.py
0 → 100644
View file @
a8c6fcf6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Step3 text model."""
from
collections.abc
import
Iterable
from
typing
import
Any
,
Optional
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
# Module-level logger, created through vLLM's init_logger and named after
# this module.
logger = init_logger(__name__)
class FusedMoEBlock(nn.Module):
    """Routed mixture-of-experts block: a replicated gate feeding FusedMoE.

    The experts are built with ``reduce_results=False``; ``forward`` performs
    the tensor-parallel all-reduce itself when more than one rank is in use.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the expert bank and the (unquantized) router gate.

        Raises:
            ValueError: if the tensor-parallel world size exceeds
                ``config.moe_num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        expert_count = config.moe_num_experts
        if expert_count < self.tp_size:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}.")

        experts_prefix = f"{prefix}.experts"
        self.experts = FusedMoE(
            num_experts=expert_count,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=experts_prefix,
        )

        # The gate is created with quant_config=None regardless of the
        # quantization config used for the experts.
        gate_prefix = f"{prefix}.gate"
        self.gate = ReplicatedLinear(
            config.hidden_size,
            expert_count,
            bias=False,
            quant_config=None,
            prefix=gate_prefix,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        input_shape = hidden_states.shape
        # Collapse all leading dimensions so each row is one token vector.
        flat_states = hidden_states.view(-1, input_shape[-1])

        logits, _ = self.gate(flat_states)
        expert_out = self.experts(hidden_states=flat_states,
                                  router_logits=logits)

        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)

        return expert_out.view(input_shape)
class Step3TextMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul
    activation, then a row-parallel down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections and the activation.

        Raises:
            ValueError: if ``hidden_act`` is anything other than ``"silu"``
                (raised after the two projections have been constructed).
        """
        super().__init__()

        # Gate and up projections share one merged column-parallel weight.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        silu_requested = hidden_act == "silu"
        if not silu_requested:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        merged, _ = self.gate_up_proj(hidden_states)
        projected, _ = self.down_proj(self.act_fn(merged))
        return projected
class Step3TextAttention(nn.Module):
    """Attention layer with a single KV head and a two-stage query path.

    ``qkv_proj`` is a replicated projection producing a query of width
    ``q_size`` plus one key and one value of width ``kv_size``. The query
    is RMS-normalised and then expanded by ``wq`` to
    ``head_dim * total_num_heads`` before rotary embedding and attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build projections, query norm, rotary embedding and attention.

        Raises:
            AssertionError: if ``num_heads`` is not divisible by the
                tensor-parallel world size.
            ValueError: if ``num_kv_heads`` is not 1.
        """
        super().__init__()
        self.hidden_size = hidden_size

        world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        # Query heads are sharded across tensor-parallel ranks.
        self.num_heads = self.total_num_heads // world_size

        if num_kv_heads != 1:
            raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
                             f"but got {num_kv_heads}.")
        self.num_kv_heads = num_kv_heads

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Falls back to head_dim when share_q_dim is None (or 0).
        self.q_size = share_q_dim or self.head_dim

        qkv_width = self.q_size + self.kv_size * 2
        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            qkv_width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        all_heads_width = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            all_heads_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)

        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embedding,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        softmax_scale = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            softmax_scale,
            self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention over ``hidden_states`` and apply ``o_proj``."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)

        # Normalise the compact query, then expand it to all local heads.
        query = self.wq(self.inter_norm(query))[0]
        query, key = self.rotary_emb(positions, query, key)

        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
class Step3TextDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense MLP or by a
    shared expert plus a routed MoE block.

    The layer index is parsed from ``prefix`` (the number following
    ``"layers."``). A layer uses MoE when that index is listed in the
    comma-separated ``moe_layers_enum`` attribute of the HF config; if the
    attribute is absent, every layer except layer 0 uses MoE.
    """

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        # All hyper-parameters are read from the wrapped HF config.
        hf_config = config.hf_config
        self.hidden_size = hf_config.hidden_size

        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=hf_config.rms_norm_eps,
            max_position_embedding=hf_config.max_position_embedding,
            head_dim=hf_config.head_dim,
            share_q_dim=hf_config.share_q_dim,
            rope_theta=hf_config.rope_theta,
            rope_scaling=getattr(hf_config, "rope_scaling", None),
            prefix=f"{prefix}.self_attn",
        )

        # e.g. "model.layers.3" -> 3
        after_layers = prefix.split("layers.")[1]
        layer_idx = int(after_layers.split(".")[0])

        enum_spec = getattr(hf_config, "moe_layers_enum", None)
        if enum_spec is None:
            # Default to 1dense.
            moe_layers_idx = list(range(1, hf_config.num_hidden_layers))
        else:
            moe_layers_idx = list(map(int, enum_spec.strip().split(',')))

        self.use_moe = layer_idx in moe_layers_idx
        if self.use_moe:
            self.moe = FusedMoEBlock(config=hf_config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.moe")
            self.share_expert = Step3TextMLP(
                hidden_size=self.hidden_size,
                intermediate_size=hf_config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert",
            )
        else:
            self.mlp = Step3TextMLP(
                hidden_size=hf_config.hidden_size,
                intermediate_size=hf_config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        norm_eps = hf_config.rms_norm_eps
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hf_config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual; otherwise the residual is passed to the input norm, which
        returns the updated pair.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=attn_in,
        ) if (attn_in := hidden_states) is not None else None

        normed, residual = self.post_attention_layernorm(attn_out, residual)

        if not self.use_moe:
            return self.mlp(normed), residual

        # Shared expert first, then routed experts; outputs are summed.
        shared = self.share_expert(normed)
        routed = self.moe(normed)
        return shared + routed, residual
@support_torch_compile
class Step3TextModel(nn.Module):
    """Decoder stack of the Step3 text model (embeddings, layers, final norm).

    Under pipeline parallelism each rank only materialises the pieces it
    owns; the rest are ``PPMissingLayer`` placeholders.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Source of the HF config, cache config and
                quantization config.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.config = config

        # Embeddings live on the first rank, and also on the last rank when
        # the word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3TextDecoderLayer(
                config=vllm_config.model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # ``forward`` reads and writes both "hidden_states" and "residual"
        # when handing activations between pipeline ranks, so the empty
        # placeholder must declare both keys.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Activations received from the previous
                pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that bypass the
                embedding lookup on the first rank.

        Returns:
            The normalised hidden states on the last rank; on other ranks an
            ``IntermediateTensors`` holding "hidden_states" and "residual".
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class Step3TextForCausalLM(nn.Module, SupportsPP):
    """Step3 text decoder with an LM head, logits processor and sampler."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder and, on the last pipeline rank, the LM head.

        Args:
            vllm_config: Source of the HF config and the LoRA config.
            prefix: Parameter-name prefix passed to the inner model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size)
            self.sampler = get_sampler()
        else:
            # Only the last rank owns the head; the logits processor and
            # sampler are not created on other ranks.
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to the inner ``Step3TextModel`` and return its output."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the module's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is tried, in order, against: the stacked
        gate/up mapping, the MoE expert mapping, the q/k/v slice mapping,
        and finally a direct load by name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded.
        """
        # q/k/v are copied into slices of a fused ``qkv_proj`` parameter.
        # Offsets are kept as exact integers in units of ``qkv_total`` and
        # scaled with integer arithmetic below; going through a float
        # fraction and ``int()`` can truncate one short when the product
        # lands just under the intended integer.
        q_dim = self.config.share_q_dim
        kv_dim = self.config.head_dim
        qkv_total = q_dim + kv_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_offset, end_offset)
            (".qkv_proj", ".q_proj", 0, q_dim),
            (".qkv_proj", ".k_proj", q_dim, q_dim + kv_dim),
            (".qkv_proj", ".v_proj", q_dim + kv_dim, qkv_total),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
        ]

        # Expert weights also contain ".gate_proj"/".up_proj"; they must not
        # be caught by the dense stacked mapping.
        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint tensor is indexed along dim 0, one
                    # slice per expert.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_off,
                         end_off) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        begin_idx = start_off * dim // qkv_total
                        end_idx = end_off * dim // qkv_total
                        param_slice = param.narrow(param.output_dim, begin_idx,
                                                   end_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params
\ No newline at end of file
vllm/model_executor/models/step3_vl.py
0 → 100644
View file @
a8c6fcf6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
itertools
import
product
from
math
import
ceil
,
sqrt
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision.transforms.functional
import
InterpolationMode
from
transformers
import
BatchFeature
,
PretrainedConfig
,
TensorType
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
class Step3VLImagePixelInputs(TypedDict):
    """Image inputs given as pixel tensors (tagged ``"pixel_values"``)."""
    # Discriminator; always the literal "pixel_values".
    type: Literal["pixel_values"]
    # Pixel tensor of the images (presumably the whole-image views;
    # NOTE(review): confirm shape against the input parser).
    pixel_values: torch.Tensor
    # Pixel tensor of the cropped patches, or None when there are none.
    patch_pixel_values: Optional[torch.Tensor]
    # Patch counts (presumably one entry per image; TODO confirm).
    num_patches: list[int]
class Step3VLImageEmbeddingInputs(TypedDict):
    """Image inputs given as embeddings (tagged ``"image_embeds"``)."""
    # Discriminator; always the literal "image_embeds".
    type: Literal["image_embeds"]
    # Image embedding tensor (NOTE(review): shape not shown here; confirm).
    image_embeds: torch.Tensor
# An image input is either raw pixels or precomputed embeddings.
Step3VLImageInputs = Union[Step3VLImagePixelInputs,
                           Step3VLImageEmbeddingInputs]

# What ``ImagePatcher.__call__`` returns for one image: the resized image,
# its cropped patches, and a per-patch boolean "newline after this patch"
# mask (``None`` when no patches were cut). The mask holds booleans, not
# integer indices.
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]

# Images whose longer side exceeds this are scaled down before patching.
MAX_IMAGE_SIZE: int = 3024
class Step3VisionProcessor:
    """Converts PIL images to normalised, square-resized tensors.

    Two pipelines are kept: one for whole images and one for patches. Both
    apply ``ToTensor``, ``Normalize`` and ``Resize`` in that order and differ
    only in the target edge length.
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        """Create the image and patch pipelines.

        Args:
            size: Edge length the whole image is resized to.
            interpolation_mode: ``"bicubic"`` selects bicubic resampling;
                any other value selects bilinear.
            patch_size: Edge length for patches; defaults to ``size``.
        """
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        if patch_size is None:
            patch_size = size

        if interpolation_mode == "bicubic":
            resample = InterpolationMode.BICUBIC
        else:
            resample = InterpolationMode.BILINEAR

        def build_pipeline(edge):
            # Same three steps for both pipelines; only ``edge`` varies.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize((edge, edge),
                                  interpolation=resample,
                                  antialias=True),
            ])

        self.transform = build_pipeline(size)
        self.patch_transform = (build_pipeline(patch_size)
                                if patch_size is not None else None)

    def __call__(self, image, is_patch=False):
        """Return ``{"pixel_values": tensor}`` with a leading batch axis.

        Args:
            image: Image accepted by the torchvision pipeline.
            is_patch: Use the patch pipeline instead of the image pipeline.
        """
        pipeline = self.patch_transform if is_patch else self.transform
        return {"pixel_values": pipeline(image).unsqueeze(0)}
class ImagePatcher:
    """Cuts an image into square sliding-window patches.

    The instance holds no state; every method depends only on its arguments
    (and the module constant ``MAX_IMAGE_SIZE``).
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Pick the window edge from the long and short side lengths.

        Returns ``0`` when the image should not be patched at all.
        """
        aspect = long / short
        if long > 728:
            return min(short, 504) if aspect > 4 else 504
        return short if aspect > 1.5 else 0

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate window boxes over a ``width`` x ``height`` canvas.

        Args:
            width: Canvas width.
            height: Canvas height.
            sizes: ``(w, h)`` window size for each scale.
            steps: ``(w, h)`` stride for each scale, paired with ``sizes``.
            img_rate_thr: Only range-checked; not otherwise used.

        Returns:
            ``(boxes, (x_num, y_num))`` where each box is
            ``(x, y, w, h)`` in row-major order and the counts are those of
            the last ``sizes``/``steps`` pair.
        """
        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"

        def axis_offsets(extent, window, stride):
            # Number of windows along one axis and their start offsets; the
            # last window is pulled back so it ends at the canvas edge.
            count = (ceil((extent - window) / stride + 1)
                     if extent > window else 1)
            offsets = [stride * k for k in range(count)]
            if len(offsets) > 1 and offsets[-1] + window > extent:
                offsets[-1] = extent - window
            return count, offsets

        per_scale = []
        for size, step in zip(sizes, steps):
            x_num, x_start = axis_offsets(width, size[0], step[0])
            y_num, y_start = axis_offsets(height, size[1], step[1])

            # product() yields (y, x) pairs row by row; flip to (x, y).
            yx = np.array(list(product(y_start, x_start)), dtype=int)
            xy = yx[:, ::-1]
            per_scale.append(np.concatenate([xy, xy + size], axis=1))

        stacked = np.concatenate(per_scale, axis=0)
        boxes = [(int(row[0]), int(row[1]), int(row[2] - row[0]),
                  int(row[3] - row[1])) for row in stacked]
        return boxes, (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` with zeros to a square, anchored at the top-left."""
        img_w, img_h = img.size
        if img_w == img_h:
            return img
        edge = max(img_w, img_h)
        canvas = Image.new(img.mode, (edge, edge), 0)
        canvas.paste(img, (0, 0))
        return canvas

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        """Return the square size for tiny, very elongated images.

        Images whose short side is below 32 and whose aspect ratio is
        outside [1/4, 4] are reported as ``max(side)`` squared; all others
        keep their size.
        """
        ratio = img_width / img_height
        shortest = min(img_height, img_width)
        if shortest < 32 and (ratio > 4 or ratio < 1 / 4):
            edge = max(img_height, img_width)
            return edge, edge
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> tuple[int, int]:
        """Scale the size down so the long side fits ``MAX_IMAGE_SIZE``."""
        longest = max(img_height, img_width)
        if longest <= MAX_IMAGE_SIZE:
            return img_width, img_height
        scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
        return int(img_width * scale_factor), int(img_height * scale_factor)

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Snap each side to a whole number of windows.

        A side shorter than one window is kept. Otherwise it is rounded up
        to the next multiple when the fractional window count exceeds 0.2,
        and down otherwise.
        """

        def snap(extent):
            ratio = extent / window_size
            if ratio < 1:
                return extent
            fraction = ratio - extent // window_size
            tiles = int(ratio) + 1 if fraction > 0.2 else int(ratio)
            return window_size * tiles

        width_new = snap(img_width)
        height_new = snap(img_height)
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` box whose top-left is row ``i``, col ``j``."""
        left, top = j, i
        return img.crop((left, top, left + tw, top + th))

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Predict ``(patch count, newline count)`` without touching pixels.

        Follows the same size pipeline as ``__call__``.
        """
        padded_w, padded_h = self.get_image_size_for_padding(
            img_width, img_height)
        scaled_w, scaled_h = self.get_image_size_for_preprocess(
            padded_w, padded_h)
        window_size = self.determine_window_size(max(scaled_h, scaled_w),
                                                 min(scaled_h, scaled_w))
        if window_size == 0:
            return 0, 0

        crop_w, crop_h = self.get_image_size_for_crop(scaled_w, scaled_h,
                                                      window_size)
        tile = (window_size, window_size)
        boxes, (x_num, _) = self.slide_window(crop_w, crop_h, [tile], [tile])
        num_boxes = len(boxes)
        full_rows = (num_boxes - 1) // x_num + 1
        # No newline is counted after the final row when it is complete.
        if num_boxes > 0 and num_boxes % x_num == 0:
            full_rows -= 1
        return num_boxes, full_rows

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        """Resize ``img`` and cut it into patches.

        Returns:
            ``(image, patches, newline_mask)``. ``patches`` is empty and the
            mask is ``None`` when no window applies; otherwise the mask is
            ``True`` for each patch that ends a row, except the last patch.
        """
        base_w, base_h = img.size
        padded_w, padded_h = self.get_image_size_for_padding(base_w, base_h)
        if (padded_w, padded_h) != (base_w, base_h):
            img = self.square_pad(img)
            base_w, base_h = img.size

        scaled_w, scaled_h = self.get_image_size_for_preprocess(
            base_w, base_h)
        img = img.resize((scaled_w, scaled_h), Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(max(scaled_h, scaled_w),
                                                 min(scaled_h, scaled_w))
        if window_size == 0:
            return img, [], None

        crop_w, crop_h = self.get_image_size_for_crop(scaled_w, scaled_h,
                                                      window_size)
        # The comparison is against the size measured before the resize
        # above, exactly as the crop decision has always been made.
        if (crop_w, crop_h) == (base_w, base_h):
            img_for_crop = img
        else:
            img_for_crop = img.resize((crop_w, crop_h),
                                      Image.Resampling.BILINEAR)

        tile = (window_size, window_size)
        boxes, (x_num, _) = self.slide_window(crop_w, crop_h, [tile], [tile])
        patches = [
            self.patch_crop(img_for_crop, top, left, box_h, box_w)
            for left, top, box_w, box_h in boxes
        ]

        row_ends = [
            idx for idx in range(len(boxes)) if (idx + 1) % x_num == 0
        ]
        if row_ends and row_ends[-1] == len(patches) - 1:
            row_ends.pop()

        if not patches:
            return img, patches, None
        return img, patches, [idx in row_ends for idx in range(len(patches))]
class Step3VLProcessor:
    """Turns text and images into tokenizer output plus pixel tensors.

    Every ``<im_patch>`` placeholder in the text is expanded into the token
    sequence for one image: its patches first, then the whole image.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
    ) -> None:
        """Store the config and tokenizer and set the fixed sizes.

        Args:
            config: Model config; stored but not read here.
            tokenizer: Tokenizer used for the text and special-token ids.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        self.image_size = 728
        self.patch_size = 504
        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Number of placeholder tokens per whole image and per patch.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)

        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Id of the image placeholder token, looked up in the vocab."""
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Return the prompt tokens one image of this size expands to.

        Each patch costs its features plus two delimiters, the whole image
        costs its features plus two delimiters, and each patch newline
        costs one token.
        """
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)

        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Run the patcher over every image, preserving order."""
        return [self.patcher(img) for img in images]

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Preprocess each image and return its ``"pixel_values"`` tensor."""
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Build the replacement text and token ids for an image's patches.

        Args:
            num_patches: Number of patches of the image.
            patch_newline_mask: Per-patch flags; a true entry appends
                ``<patch_newline>`` after that patch. ``None`` means no
                newlines.

        Returns:
            ``(text, token_ids)``; both empty when there are no patches.
        """
        if num_patches <= 0:
            return "", []
        if patch_newline_mask is not None:
            assert len(patch_newline_mask) == num_patches

        # The per-patch text and ids are identical for every patch, so they
        # are built once. ``image_token_id`` goes through ``get_vocab()``,
        # which is worth doing a single time rather than per patch.
        to_id = self.tokenizer.convert_tokens_to_ids
        patch_text = f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
        patch_ids = ([to_id("<patch_start>")] +
                     [self.image_token_id] * self.num_patch_feature_size +
                     [to_id("<patch_end>")])
        # Only resolve the newline token when at least one patch needs it.
        has_newline = bool(patch_newline_mask) and any(patch_newline_mask)
        newline_id = to_id("<patch_newline>") if has_newline else None

        text_parts = []
        token_ids = []
        for i in range(num_patches):
            text_parts.append(patch_text)
            token_ids.extend(patch_ids)
            if patch_newline_mask and patch_newline_mask[i]:
                text_parts.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(text_parts), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Build the whole-image replacement, repeated ``num_images`` times."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        """Concatenate the patch replacement and the image replacement.

        Returns:
            ``(text, token_ids)`` with the patches placed before the image.
        """
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th ``placeholder`` in ``text`` with ``repls[i]``.

        Raises:
            ValueError: If the number of placeholders differs from the
                number of replacements.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."  # noqa: E501
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into one batch.

        Args:
            text: One prompt or a list of prompts.
            images: One image or a list of images.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.

        Returns:
            A ``BatchFeature`` with the tokenizer output and, when images
            are given, ``pixel_values`` and ``num_patches`` plus
            ``patch_pixel_values`` / ``patch_newline_mask`` when any patch
            exists.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:  # noqa: E501
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                num_patches.append(len(img_patches))

                # Only the text form is needed here; the ids are discarded.
                image_repl_str, _ = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            # Every prompt receives the full list of replacements, so each
            # must contain one placeholder per image.
            text = [
                self.replace_placeholder(t, self.image_token,
                                         image_repl_str_lst) for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
class Step3VLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Step3-VL: processor factory and token counts."""

    def get_hf_processor(self) -> Step3VLProcessor:
        """Build a new ``Step3VLProcessor`` from the HF config and tokenizer."""
        return Step3VLProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only images are supported, with no fixed per-prompt limit."""
        return {"image": None}

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest supported image."""
        hf_processor = self.get_hf_processor()
        largest = self.get_image_size_with_most_features()
        return hf_processor.get_num_image_tokens(largest.width,
                                                 largest.height)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token ceiling; both arguments are unused."""
        return {"image": self.get_max_image_tokens()}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used as the worst case for token counts."""
        # Same bound the patcher clamps images to.
        return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)

    def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
        """Return the total prompt tokens for the images in ``mm_data``.

        Args:
            mm_data: Must hold exactly one key, ``"image"``, mapping to an
                image or a list/tuple of images exposing ``width`` and
                ``height``.

        Raises:
            ValueError: If ``mm_data`` has other keys or lacks ``"image"``.
        """
        if len(mm_data) != 1 or "image" not in mm_data:
            raise ValueError(
                "mm_data could only contain one key 'image' for step1o")

        image_data = mm_data["image"]
        if not isinstance(image_data, (list, tuple)):
            image_data = [image_data]

        # One processor serves every image; building one per image only
        # repeats the same construction. An empty list still sums to 0, and
        # in that case no processor is built, as before.
        if not image_data:
            return 0
        hf_processor = self.get_hf_processor()
        return sum(
            hf_processor.get_num_image_tokens(img.width, img.height)
            for img in image_data)
class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
    """Builds placeholder prompts and images for Step3-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<im_patch>`` placeholder per requested image."""
        return "<im_patch>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size with the most features.

        Args:
            seq_len: Unused.
            mm_counts: Requested item count per modality; only ``"image"``
                is read and it defaults to 0.
        """
        largest = self.info.get_image_size_with_most_features()
        target_width, target_height = largest
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=target_width,
                                              height=target_height,
                                              num_images=image_count)
        return {"image": dummy_images}
class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]
                                 ):
    """Multimodal processor wiring Step3-VL image placeholders and fields."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image placeholder token is expanded.

        Args:
            mm_items: Unused here.
            hf_processor_mm_kwargs: Forwarded to ``get_hf_processor``.
            out_mm_kwargs: Processed multimodal outputs; supplies
                ``num_patches`` and each image's ``patch_newline_mask``.

        Returns:
            A single ``PromptReplacement`` targeting the image token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = processor.image_token_id
        patches_per_image = out_mm_kwargs["num_patches"].tolist()

        def get_replacement_step1o(item_idx: int):
            # Token ids that replace the placeholder of image ``item_idx``.
            item = out_mm_kwargs.get_item("image", item_idx)
            patch_count = patches_per_image[item_idx]
            if patch_count > 0:
                newline_mask = item["patch_newline_mask"].data.tolist()
            else:
                # Without patches there is no mask to read.
                patch_count, newline_mask = 0, None
            _, repl_ids = processor._get_image_repl_features(
                1, patch_count, newline_mask)

            return PromptUpdateDetails.select_token_id(
                seq=repl_ids,
                embed_token_id=placeholder_id,
            )

        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=get_replacement_step1o,
        )
        return [replacement]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare how each processor output maps onto images.

        ``pixel_values`` and ``num_patches`` are batched per image; the
        patch tensors are flat and split by ``num_patches``.
        """
        patch_counts = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values":
            MultiModalFieldConfig.batched("image"),
            "patch_pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", patch_counts),
            "num_patches":
            MultiModalFieldConfig.batched("image"),
            "patch_newline_mask":
            MultiModalFieldConfig.flat_from_sizes("image", patch_counts),
        }
def get_abs_pos(abs_pos, tgt_size):
    """Resize a position-embedding table to a different patch grid.

    Args:
        abs_pos: Embedding table whose first row (after squeezing dim 0) is
            the class token and whose remaining rows form a square grid.
        tgt_size: Target number of grid positions; its integer square root
            is the new grid edge.

    Returns:
        ``abs_pos`` itself when the grid edge is unchanged; otherwise a
        ``(1, edge * edge + 1, dim)`` tensor with the class row kept and
        the grid bicubically interpolated.
    """
    table = abs_pos.squeeze(0)
    src_size = int(math.sqrt(table.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))

    if src_size == tgt_size:
        return abs_pos

    dim = abs_pos.size(-1)
    dtype = abs_pos.dtype
    cls_token = table[:1]
    grid = table[1:]

    # Interpolation runs channels-first and in float32, then the result is
    # cast back to the original dtype.
    grid = grid.view(1, src_size, src_size, dim)
    grid = grid.permute(0, 3, 1, 2).contiguous().to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_size, tgt_size),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(dtype)

    num_positions = tgt_size * tgt_size
    resized = resized.permute(0, 2, 3, 1).view(num_positions, dim)
    combined = torch.cat([cls_token, resized], dim=0)
    return combined.view(1, num_positions + 1, dim)
class Step3VisionEmbeddings(nn.Module):
    """Patch, class and position embeddings of the vision encoder."""

    def __init__(self, config: Step3VisionEncoderConfig):
        """Create the embedding parameters.

        Args:
            config: Provides ``hidden_size``, ``image_size``, ``patch_size``
                and ``num_channels``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        cls_init = torch.randn(1, self.embed_dim)
        self.class_embedding = nn.Parameter(cls_init)

        # Non-overlapping patches: kernel and stride are both the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        grid_edge = self.image_size // self.patch_size
        self.num_patches = grid_edge**2
        self.pad_tp_size = 4  # hard code for padding
        # To load the pretrained weights, we still use P+1 as the seqlen
        table_len = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(table_len,
                                                     self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(table_len).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a token sequence.

        The output sequence is: ``pad_tp_size - 1`` copies of the
        position-encoded class token, then that class token, then the
        patch tokens.
        """
        batch_size = pixel_values.shape[0]

        # shape = [*, width, grid, grid] -> [*, grid * grid, width]
        patch_tokens = self.patch_embedding(pixel_values)
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        # The position table is resized to the actual number of patches.
        pos_table = self.position_embedding(self.position_ids)
        tokens = tokens + get_abs_pos(pos_table, patch_tokens.size(1))

        # pad: repeat the class token in front of the sequence.
        cls_padding = tokens[:, 0, :].unsqueeze(1).repeat(
            1, self.pad_tp_size - 1, 1)
        return torch.cat([cls_padding, tokens], dim=1)
class
Step3VisionAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
total_num_heads
self
.
scale
=
self
.
head_dim
**-
0.5
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
total_num_heads
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
out_proj
=
RowParallelLinear
(
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
prefix
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
):
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
_
=
hidden_states
.
size
()
# get query proj
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
=
q
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
self
.
head_dim
)
v
=
v
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
self
.
head_dim
)
q
=
q
.
transpose
(
1
,
2
)
k
=
k
.
transpose
(
1
,
2
)
v
=
v
.
transpose
(
1
,
2
)
attn_output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
scale
=
self
.
scale
,
is_causal
=
False
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
reshape
(
bsz
,
tgt_len
,
self
.
num_heads
*
self
.
head_dim
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
class Step3VisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder."""

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the activation and the two projections.

        Args:
            config: Provides ``hidden_act``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Optional quantization config for both layers.
            prefix: Parameter-name prefix, passed to both layers as given.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden = config.hidden_size
        inner = config.intermediate_size
        self.fc1 = ColumnParallelLinear(hidden,
                                        inner,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=prefix)
        self.fc2 = RowParallelLinear(inner,
                                     hidden,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=prefix)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation, then ``fc2``."""
        # Both projections return a pair; only the first item is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
class Step3VisionEncoderLayer(nn.Module):
    """One vision encoder layer: attention and MLP, each with a residual."""

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the attention, the MLP and their layer norms.

        Args:
            config: Provides ``hidden_size`` and ``layer_norm_eps``.
            quant_config: Optional quantization config for the sub-blocks.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = Step3VisionAttention(config,
                                              quant_config,
                                              prefix=f"{prefix}.self_attn")
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.FloatTensor:
        """Apply attention then MLP, adding each back to its input.

        Each layer norm is applied to the sub-block's output, before the
        residual addition.
        """
        attn_branch = self.layer_norm1(self.self_attn(hidden_states))
        after_attn = hidden_states + attn_branch

        mlp_branch = self.layer_norm2(self.mlp(after_attn))
        return after_attn + mlp_branch
class Step3VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` vision encoder layers."""

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the layers, named ``{prefix}.layers.{index}``."""
        super().__init__()
        self.config = config

        stack = []
        for i in range(config.num_hidden_layers):
            stack.append(
                Step3VisionEncoderLayer(config,
                                        quant_config,
                                        prefix=f"{prefix}.layers.{i}"))
        self.layers = nn.ModuleList(stack)

    def forward(
        self,
        inputs_embeds,
    ):
        """Feed ``inputs_embeds`` through every layer in order."""
        current = inputs_embeds
        for block in self.layers:
            current = block(current)
        return current
class Step3VisionTransformer(nn.Module):
    """Vision tower: embeddings followed by the encoder stack."""

    def __init__(self,
                 config: Step3VisionEncoderConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create the embeddings and the encoder.

        Args:
            config: Vision encoder config; provides ``image_size``.
            quant_config: Optional quantization config for the encoder.
            prefix: Parameter-name prefix; the encoder is placed under
                ``{prefix}.transformer``.
        """
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.embeddings = Step3VisionEmbeddings(config)
        encoder_prefix = f"{prefix}.transformer"
        self.transformer = Step3VisionEncoder(config,
                                              quant_config,
                                              prefix=encoder_prefix)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ):
        """Embed ``pixel_values`` and run the result through the encoder."""
        embedded = self.embeddings(pixel_values)
        return self.transformer(inputs_embeds=embedded)
@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor,
                                        info=Step3VLProcessingInfo,
                                        dummy_inputs=Step3VLDummyInputsBuilder)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """Step3 vision-language model.

    Combines a ``Step3VisionTransformer`` vision tower, a projector made of
    two strided convolutions plus a linear layer, and a language model
    built from ``config.text_config``.
    """

    # Checkpoint names starting with "model." / "lm_head." are remapped so
    # that they land under the ``language_model`` sub-module.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder string for item ``i`` of a modality.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<im_patch>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Engine configuration; the HF config and the
                multimodal config are read from its ``model_config``.
            prefix: Name prefix used when naming sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # The vision tower is created with quant_config=None.
        self.vision_model = Step3VisionTransformer(config.vision_config,
                                                   None,
                                                   prefix=maybe_prefix(
                                                       prefix,
                                                       "vision_model"))
        # Projector: conv (k=2) -> conv (k=3, stride 2, pad 1) -> linear
        # into the language model hidden size.
        self.vit_downsampler = nn.Conv2d(
            config.vision_config.hidden_size,
            config.vision_config.output_hidden_size,
            kernel_size=2,
            stride=config.understand_projector_stride)
        self.vit_downsampler2 = nn.Conv2d(
            config.vision_config.output_hidden_size,
            config.vision_config.output_hidden_size * 2,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.vit_large_projector = nn.Linear(
            config.vision_config.output_hidden_size * 2,
            config.hidden_size,
            bias=config.projector_bias,
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"))

        # Re-export the language model's factory under the same name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a new one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter of this module."""
        return next(self.parameters()).dtype

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Step3VLImageInputs]:
        """Turn raw multimodal keyword arguments into a typed input dict.

        Pops ``pixel_values``, ``patch_pixel_values``, ``num_patches`` and
        ``image_embeds`` from ``kwargs``. Pixel inputs take precedence over
        pre-computed embeddings when both are supplied.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            given; otherwise a ``Step3VLImagePixelInputs`` or
            ``Step3VLImageEmbeddingInputs`` whose tensors have been cast to
            this module's dtype and moved to its device.

        Raises:
            ValueError: If ``image_embeds`` has fewer than two dimensions.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)
            if pixel_values.dim() >= 3:
                # Collapse every leading dim, keeping the last three.
                pixel_values = pixel_values.view(-1,
                                                 *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = flatten_bn(patch_pixel_values,
                                                concat=True)
                patch_pixel_values = patch_pixel_values.view(
                    -1, *patch_pixel_values.shape[-3:])
                # Handle empty patch_pixel_values by setting to None
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None
            num_patches = flatten_bn(num_patches, concat=True).tolist()

            return Step3VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype).to(self.device),
                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
                    self.device) if patch_pixel_values is not None else None,
                num_patches=num_patches,
            )

        if image_embeds is not None:
            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
                # Flatten to 2-D, keeping only the last dimension.
                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
            else:
                raise ValueError(
                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
                )

            return Step3VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds.to(self.dtype).to(self.device),
            )
        return None

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        """Project vision-tower features into the language model space.

        The second dimension ``P`` of the input is reshaped into an
        ``HW x HW`` grid with ``HW = int(sqrt(P))``, passed through both
        down-sampling convolutions, flattened back to a sequence and fed to
        the linear projector.
        """
        B, P = image_features.shape[:2]
        HW = int(sqrt(P))
        # Channels first so the convolutions see (B, C, HW, HW).
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vit_downsampler(image_features)
        image_features = self.vit_downsampler2(image_features)
        n_dim = image_features.size(1)
        # Back to (B, tokens, channels) for the linear projector.
        image_features = image_features.view(B, n_dim, -1).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and drop the first 4 positions of dim 1.

        NOTE(review): the four dropped positions are presumably non-patch
        prefix tokens added by the embeddings module - confirm there.
        """
        return self.vision_model(input_tensor)[:, 4:]

    def _process_image_input(
            self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]:
        """Compute one embedding tensor per image.

        For pixel inputs, each image's patch features (if it has any) are
        concatenated in front of its global-image features.

        For ``image_embeds`` inputs the embeddings are returned as-is.
        Previously that path crashed: the 2-D embeddings were pushed through
        ``_process_image_features`` and ``patch_image_features`` /
        ``num_patches`` were never bound.
        """
        if image_input["type"] == "image_embeds":
            # NOTE(review): assumes pre-computed embeddings are already in
            # the language model's hidden space and therefore skip the
            # projector. They were flattened to 2-D during validation, so
            # they are returned as a single item - confirm whether
            # per-image splitting is required by the caller.
            return (image_input["image_embeds"], )

        image_features = self._get_vision_model_output(
            image_input["pixel_values"])
        patch_image_features = self._get_vision_model_output(
            image_input["patch_pixel_values"]
        ) if image_input["patch_pixel_values"] is not None else None
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = self._process_image_features(
            patch_image_features) if patch_image_features is not None else None

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                # Patches of all images are stored back to back; take the
                # slice that belongs to image ``i``.
                patch_slice = patch_image_features[
                    cur_patch_idx:cur_patch_idx + num_patch]
                cur_feature.append(
                    patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) >
                1 else cur_feature[0])
        return merged_image_features

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for ``kwargs``, or ``None`` if no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in the image embeddings.

        Args:
            input_ids: Token ids of the prompt.
            multimodal_embeddings: Image embeddings to place at positions
                equal to ``config.image_token_id``. ``None`` or an empty
                container means "text only".

        Returns:
            The input embeddings for the language model.
        """
        # An empty container is treated like None: otherwise the rows at
        # image positions would stay uninitialised (torch.empty) and the
        # merge helper would be asked to merge nothing.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
        else:
            # Only text tokens are embedded by the language model; image
            # positions are filled by merge_multimodal_embeddings below.
            is_text = input_ids != self.config.image_token_id
            text_ids = input_ids[is_text]
            text_embeds = self.language_model.model.get_input_embeddings(
                text_ids)
            inputs_embeds = torch.empty(input_ids.shape[0],
                                        text_embeds.shape[-1],
                                        dtype=text_embeds.dtype,
                                        device=text_embeds.device)
            inputs_embeds[is_text] = text_embeds
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model on text plus optional image inputs.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is
        discarded. Otherwise, if no ``inputs_embeds`` were supplied, they
        are computed here from ``input_ids`` and the image ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load ``weights`` after renaming them via ``hf_to_vllm_mapper``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(self)
        loaded_weights = loader.load_weights(weights,
                                             mapper=self.hf_to_vllm_mapper)
        return loaded_weights
\ No newline at end of file
vllm/reasoning/__init__.py
View file @
a8c6fcf6
...
@@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
...
@@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from
.glm4_moe_reasoning_parser
import
Glm4MoeModelReasoningParser
from
.glm4_moe_reasoning_parser
import
Glm4MoeModelReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
from
.step3_reasoning_parser
import
Step3ReasoningParser
__all__
=
[
__all__
=
[
"ReasoningParser"
,
"ReasoningParser"
,
...
@@ -14,4 +15,5 @@ __all__ = [
...
@@ -14,4 +15,5 @@ __all__ = [
"GraniteReasoningParser"
,
"GraniteReasoningParser"
,
"Qwen3ReasoningParser"
,
"Qwen3ReasoningParser"
,
"Glm4MoeModelReasoningParser"
,
"Glm4MoeModelReasoningParser"
,
"Step3ReasoningParser"
,
]
]
vllm/reasoning/step3_reasoning_parser.py
0 → 100644
View file @
a8c6fcf6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Step3 model.

    The Step3 model uses </think> token to denote the end of reasoning
    text. This parser extracts all content before </think> as reasoning
    content and everything after it as regular content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Resolve the ``</think>`` token id from the tokenizer vocabulary.

        Raises:
            ValueError: If no tokenizer is available.
            RuntimeError: If ``</think>`` is not in the vocabulary.
        """
        super().__init__(tokenizer)
        self.think_end_token = "</think>"

        # Kept as a public attribute; not used by the methods of this class.
        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.

        Handles streaming output where previous + delta = current.
        Token ids decide which side of </think> the delta is on.
        For text "abc</think>xyz":
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            ``None`` when the delta is exactly the ``</think>`` token,
            otherwise a ``DeltaMessage`` carrying the reasoning and/or
            content part of ``delta_text``.
        """
        # Skip single special token
        if len(delta_token_ids
               ) == 1 and delta_token_ids[0] == self.think_end_token_id:
            return None

        if self.think_end_token_id in delta_token_ids:
            # </think> in delta, extract reasoning content and remaining
            # content. partition() splits at the first occurrence; if the
            # token id is present but its text is not in delta_text, the
            # whole delta is treated as reasoning. The previous
            # find()-based slicing used the -1 "not found" result as an
            # index and corrupted both parts in that case.
            reasoning_content, _, content = delta_text.partition(
                self.think_end_token)
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=content if content else None)
        elif self.think_end_token_id in previous_token_ids:
            # </think> already seen in previous text, everything is content
            return DeltaMessage(content=delta_text)
        else:
            # No </think> seen yet, everything is reasoning
            return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Split a complete model output at the first ``</think>``.

        Returns:
            ``(reasoning_content, content)``. Without a ``</think>`` the
            whole output is reasoning and content is ``None``; content is
            also ``None`` when nothing follows the token.
        """
        reasoning_content, marker, content = model_output.partition(
            self.think_end_token)

        if not marker:
            # If no </think> token, everything is reasoning content
            return model_output, None

        return reasoning_content, content if content else None

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether the ``</think>`` token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids that follow the first ``</think>`` token.

        An empty list is returned when the token is absent or is the very
        last id, i.e. when no content token exists yet.
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []

        return input_ids[input_ids.index(self.think_end_token_id) + 1:]
\ No newline at end of file
vllm/transformers_utils/config.py
View file @
a8c6fcf6
...
@@ -39,6 +39,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
...
@@ -39,6 +39,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig
,
MPTConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
NemotronConfig
,
NVLM_D_Config
,
OvisConfig
,
RWConfig
,
OvisConfig
,
RWConfig
,
Step3TextConfig
,
Step3VLConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
Telechat2Config
,
UltravoxConfig
)
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
# yapf: enable
...
@@ -97,6 +98,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
...
@@ -97,6 +98,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"skywork_chat"
:
SkyworkR1VChatConfig
,
"skywork_chat"
:
SkyworkR1VChatConfig
,
"telechat"
:
Telechat2Config
,
"telechat"
:
Telechat2Config
,
"ultravox"
:
UltravoxConfig
,
"ultravox"
:
UltravoxConfig
,
"step3_vl"
:
Step3VLConfig
,
"step3_text"
:
Step3TextConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
**
_CONFIG_REGISTRY_OVERRIDE_HF
}
}
...
...
vllm/transformers_utils/configs/__init__.py
View file @
a8c6fcf6
...
@@ -28,6 +28,9 @@ from vllm.transformers_utils.configs.ovis import OvisConfig
...
@@ -28,6 +28,9 @@ from vllm.transformers_utils.configs.ovis import OvisConfig
from
vllm.transformers_utils.configs.skyworkr1v
import
SkyworkR1VChatConfig
from
vllm.transformers_utils.configs.skyworkr1v
import
SkyworkR1VChatConfig
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.telechat2
import
Telechat2Config
from
vllm.transformers_utils.configs.telechat2
import
Telechat2Config
from
vllm.transformers_utils.configs.step3_vl
import
(
Step3TextConfig
,
Step3VisionEncoderConfig
,
Step3VLConfig
)
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
__all__
=
[
__all__
=
[
...
@@ -56,4 +59,7 @@ __all__ = [
...
@@ -56,4 +59,7 @@ __all__ = [
"SolarConfig"
,
"SolarConfig"
,
"Telechat2Config"
,
"Telechat2Config"
,
"UltravoxConfig"
,
"UltravoxConfig"
,
"Step3VLConfig"
,
"Step3VisionEncoderConfig"
,
"Step3TextConfig"
,
]
]
vllm/transformers_utils/configs/step3_vl.py
0 → 100644
View file @
a8c6fcf6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
,
Union
from
transformers.configuration_utils
import
PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration of the Step3 vision encoder.

    Every constructor argument is stored under an attribute of the same
    name; remaining keyword arguments go to ``PretrainedConfig``.
    """

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size

        # Activation / normalisation settings.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        # The base class is initialised last, after the fields above.
        super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
    """Configuration of the Step3 text (language) model.

    Every constructor argument is stored under an attribute of the same
    name; remaining keyword arguments go to ``PretrainedConfig``.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Same value as the literal tuple (4, 5, ..., 59).
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Core transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.head_dim = head_dim
        self.share_q_dim = share_q_dim
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Sequence length and rotary embedding settings.
        self.max_seq_len = max_seq_len
        self.max_position_embedding = max_position_embedding
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.share_expert_dim = share_expert_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        # The base class is initialised last, after the fields above.
        super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
    """Top-level Step3 vision-language configuration.

    Holds a ``Step3VisionEncoderConfig`` and a ``Step3TextConfig`` plus the
    projector settings and the image placeholder token id.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        """Normalise the sub-configs and store all settings.

        ``vision_config`` / ``text_config`` may each be ``None`` (defaults
        are used), a ``dict`` (expanded as keyword arguments of the
        matching config class) or an already constructed config object
        (stored unchanged).
        """
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()

        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()

        self.vision_config = vision_config
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text model's hidden size at the top level.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        # The base class is initialised last, after the fields above.
        super().__init__(**kwargs)
\ No newline at end of file
vllm/v1/attention/backends/mla/common.py
View file @
a8c6fcf6
...
@@ -690,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -690,7 +690,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def
can_run_in_cudagraph
(
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
#return common_attn_metadata.max_query_len == 1
if
not
self
.
use_spec_decode
:
return
common_attn_metadata
.
max_query_len
==
1
return
self
.
_num_prefills
==
0
return
self
.
_num_prefills
==
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment