Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
39052dbc
Unverified
Commit
39052dbc
authored
Aug 11, 2025
by
Maximilien de Bayser
Committed by
GitHub
Aug 10, 2025
Browse files
Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by:
Max de Bayser
<
mbayser@br.ibm.com
>
parent
9c97a1c3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
235 additions
and
130 deletions
+235
-130
tests/entrypoints/openai/test_rerank.py
tests/entrypoints/openai/test_rerank.py
+3
-1
tests/entrypoints/openai/test_score.py
tests/entrypoints/openai/test_score.py
+3
-1
tests/models/language/pooling/test_scoring.py
tests/models/language/pooling/test_scoring.py
+9
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-28
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+31
-51
vllm/entrypoints/score_utils.py
vllm/entrypoints/score_utils.py
+37
-3
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+63
-25
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+17
-19
vllm/pooling_params.py
vllm/pooling_params.py
+6
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+40
-0
No files found.
tests/entrypoints/openai/test_rerank.py
View file @
39052dbc
...
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
...
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
invocation_output
[
"results"
]):
invocation_output
[
"results"
]):
assert
rerank_result
.
keys
()
==
invocations_result
.
keys
()
assert
rerank_result
.
keys
()
==
invocations_result
.
keys
()
assert
rerank_result
[
"relevance_score"
]
==
pytest
.
approx
(
assert
rerank_result
[
"relevance_score"
]
==
pytest
.
approx
(
invocations_result
[
"relevance_score"
],
rel
=
0.01
)
invocations_result
[
"relevance_score"
],
rel
=
0.05
)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_score.py
View file @
39052dbc
...
@@ -220,7 +220,9 @@ class TestModel:
...
@@ -220,7 +220,9 @@ class TestModel:
invocation_output
[
"data"
]):
invocation_output
[
"data"
]):
assert
score_data
.
keys
()
==
invocation_data
.
keys
()
assert
score_data
.
keys
()
==
invocation_data
.
keys
()
assert
score_data
[
"score"
]
==
pytest
.
approx
(
assert
score_data
[
"score"
]
==
pytest
.
approx
(
invocation_data
[
"score"
],
rel
=
0.01
)
invocation_data
[
"score"
],
rel
=
0.05
)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
def
test_activation
(
self
,
server
:
RemoteOpenAIServer
,
model
:
dict
[
str
,
def
test_activation
(
self
,
server
:
RemoteOpenAIServer
,
model
:
dict
[
str
,
Any
]):
Any
]):
...
...
tests/models/language/pooling/test_scoring.py
View file @
39052dbc
...
@@ -23,6 +23,15 @@ TEXTS_2 = [
...
@@ -23,6 +23,15 @@ TEXTS_2 = [
"The capital of Germany is Berlin."
,
"The capital of Germany is Berlin."
,
]
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Run every test in this module with both engines.

    This is a simple autouse wrapper around the ``run_with_both_engines``
    fixture. It can be promoted up to conftest.py to run for every test
    in a package.
    """
DTYPE
=
"half"
DTYPE
=
"half"
...
...
vllm/entrypoints/llm.py
View file @
39052dbc
...
@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
...
@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
,
parse_chat_messages
,
resolve_chat_template_content_format
)
resolve_chat_template_content_format
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
ScoreMultiModalParam
,
ScoreMultiModalParam
,
_cosine_similarity
,
_cosine_similarity
,
_validate_score_input_lens
,
_validate_score_input_lens
,
compress_token_type_ids
,
get_score_prompt
)
get_score_prompt
)
# yapf: enable
from
vllm.entrypoints.utils
import
(
_validate_truncation_size
,
from
vllm.entrypoints.utils
import
(
_validate_truncation_size
,
log_non_default_args
)
log_non_default_args
)
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
...
@@ -1329,6 +1333,7 @@ class LLM:
...
@@ -1329,6 +1333,7 @@ class LLM:
model_config
=
self
.
llm_engine
.
model_config
model_config
=
self
.
llm_engine
.
model_config
pooling_params
.
verify
(
"score"
,
model_config
)
pooling_params
.
verify
(
"score"
,
model_config
)
pooling_params_list
=
list
[
PoolingParams
]()
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
...
@@ -1339,38 +1344,31 @@ class LLM:
...
@@ -1339,38 +1344,31 @@ class LLM:
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
if
model_config
.
is_multimodal_model
:
model_config
=
self
.
llm_engine
.
model_config
for
q
,
d
in
input_pairs
:
_
,
engine_prompt
=
get_score_prompt
(
model_config
=
model_config
,
data_1
=
q
,
data_2
=
d
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
)
parsed_prompts
.
append
(
engine_prompt
)
for
q
,
d
in
input_pairs
:
else
:
_
,
engine_prompt
=
get_score_prompt
(
for
q
,
t
in
input_pairs
:
model_config
=
model_config
,
if
model_config
.
use_pad_token
:
data_1
=
q
,
# cross_encoder models defaults to using pad_token.
data_2
=
d
,
prompt_inputs
=
tokenizer
(
tokenizer
=
tokenizer
,
text
=
q
,
# type: ignore[arg-type]
tokenization_kwargs
=
tokenization_kwargs
,
text_pair
=
t
,
# type: ignore[arg-type]
)
**
tokenization_kwargs
)
else
:
if
envs
.
VLLM_USE_V1
and
(
token_type_ids
:
=
engine_prompt
.
pop
(
# `llm as reranker` models defaults to not using pad_token.
"token_type_ids"
,
None
)):
prompt_inputs
=
tokenizer
(
params
=
pooling_params
.
clone
()
text
=
q
+
t
,
# type: ignore[operator]
compressed
=
compress_token_type_ids
(
token_type_ids
)
**
tokenization_kwargs
)
params
.
extra_kwargs
=
{
"compressed_token_type_ids"
:
compressed
}
engine_prompt
=
TokensPrompt
(
pooling_params_list
.
append
(
params
)
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
else
:
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
pooling_params_list
.
append
(
pooling_params
)
parsed_prompts
.
append
(
engine_prompt
)
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
prompts
=
parsed_prompts
,
params
=
pooling_params
,
params
=
pooling_params
_list
,
use_tqdm
=
use_tqdm
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
...
...
vllm/entrypoints/openai/serving_score.py
View file @
39052dbc
...
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
...
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm
import
envs
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
...
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
ScoreResponseData
,
UsageInfo
)
ScoreResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
ScoreMultiModalParam
,
ScoreMultiModalParam
,
_cosine_similarity
,
_cosine_similarity
,
_validate_score_input_lens
,
_validate_score_input_lens
,
compress_token_type_ids
,
get_score_prompt
)
get_score_prompt
)
# yapf: enable
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
...
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
self
.
_validate_input
(
request
,
engine_prompt
[
"prompt_token_ids"
],
full_prompt
)
if
request
.
mm_processor_kwargs
is
not
None
:
if
request
.
mm_processor_kwargs
is
not
None
:
engine_prompt
[
"mm_processor_kwargs"
]
=
request
.
mm_processor_kwargs
engine_prompt
[
"mm_processor_kwargs"
]
=
request
.
mm_processor_kwargs
...
@@ -188,64 +195,27 @@ class ServingScores(OpenAIServing):
...
@@ -188,64 +195,27 @@ class ServingScores(OpenAIServing):
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
if
self
.
model_config
.
is_multimodal_model
:
preprocess_async
=
make_async
(
self
.
_preprocess_score
,
executor
=
self
.
_tokenizer_executor
)
preprocess_async
=
make_async
(
self
.
_preprocess_score
,
preprocessed_prompts
=
await
asyncio
.
gather
(
executor
=
self
.
_tokenizer_executor
)
*
(
preprocess_async
(
request
=
request
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
data_1
=
t1
,
data_2
=
t2
)
for
t1
,
t2
in
input_pairs
))
preprocessed_prompts
=
await
asyncio
.
gather
(
for
full_prompt
,
engine_prompt
in
preprocessed_prompts
:
*
(
preprocess_async
(
request
=
request
,
request_prompts
.
append
(
full_prompt
)
tokenizer
=
tokenizer
,
engine_prompts
.
append
(
engine_prompt
)
tokenization_kwargs
=
tokenization_kwargs
,
data_1
=
t1
,
data_2
=
t2
)
for
t1
,
t2
in
input_pairs
))
for
full_prompt
,
engine_prompt
in
preprocessed_prompts
:
request_prompts
.
append
(
full_prompt
)
engine_prompts
.
append
(
engine_prompt
)
else
:
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
use_pad_token
=
self
.
model_config
.
use_pad_token
if
use_pad_token
:
# cross_encoder models defaults to using pad_token.
tokenized_prompts
=
await
asyncio
.
gather
(
*
(
tokenize_async
(
text
=
t1
,
# type: ignore[arg-type]
text_pair
=
t2
,
# type: ignore[arg-type]
**
tokenization_kwargs
)
for
t1
,
t2
in
input_pairs
))
else
:
# `llm as reranker` models defaults to not using pad_token.
tokenized_prompts
=
await
asyncio
.
gather
(
*
(
tokenize_async
(
text
=
t1
+
# type: ignore[operator]
t2
,
**
tokenization_kwargs
)
for
t1
,
t2
in
input_pairs
))
for
prompt_inputs
,
(
t1
,
t2
)
in
zip
(
tokenized_prompts
,
input_pairs
):
sep_token
=
tokenizer
.
sep_token
if
(
tokenizer
.
sep_token
and
use_pad_token
)
else
''
request_prompt
=
f
"
{
t1
}{
sep_token
}{
t2
}
"
input_ids
=
prompt_inputs
[
"input_ids"
]
text_token_prompt
=
\
self
.
_validate_input
(
request
,
input_ids
,
request_prompt
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
text_token_prompt
[
"prompt_token_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
request_prompts
.
append
(
request_prompt
)
engine_prompts
.
append
(
engine_prompt
)
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
pooling_params
=
request
.
to_pooling_params
()
default_
pooling_params
=
request
.
to_pooling_params
()
try
:
try
:
pooling_params
.
verify
(
"score"
,
self
.
model_config
)
default_
pooling_params
.
verify
(
"score"
,
self
.
model_config
)
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
...
@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
request_prompts
[
i
],
params
=
pooling_params
,
params
=
default_
pooling_params
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
if
envs
.
VLLM_USE_V1
and
(
token_type_ids
:
=
engine_prompt
.
pop
(
"token_type_ids"
,
None
)):
pooling_params
=
default_pooling_params
.
clone
()
compressed
=
compress_token_type_ids
(
token_type_ids
)
pooling_params
.
extra_kwargs
=
{
"compressed_token_type_ids"
:
compressed
}
else
:
pooling_params
=
(
default_pooling_params
)
generator
=
self
.
engine_client
.
encode
(
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
engine_prompt
,
pooling_params
,
pooling_params
,
...
...
vllm/entrypoints/score_utils.py
View file @
39052dbc
...
@@ -184,15 +184,49 @@ def get_score_prompt(
...
@@ -184,15 +184,49 @@ def get_score_prompt(
model_config
,
model_config
,
tokenizer
,
tokenizer
,
)
)
from
vllm.model_executor.model_loader
import
get_model_cls
full_prompt
=
apply_score_template
(
model_config
,
prompt_1
,
prompt_2
)
model
=
get_model_cls
(
model_config
)
if
supports_score_template
(
model
):
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
full_prompt
=
apply_score_template
(
model_config
,
prompt_1
,
prompt_2
)
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
elif
model_config
.
use_pad_token
:
# cross_encoder models defaults to using pad_token.
prompt_inputs
=
tokenizer
(
text
=
prompt_1
,
text_pair
=
prompt_2
,
**
tokenization_kwargs
)
full_prompt
=
tokenizer
.
decode
(
prompt_inputs
[
"input_ids"
])
else
:
# `llm as reranker` models defaults to not using pad_token.
full_prompt
=
prompt_1
+
prompt_2
prompt_inputs
=
tokenizer
(
text
=
full_prompt
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
])
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
])
if
(
token_type_ids
:
=
prompt_inputs
.
get
(
"token_type_ids"
))
is
not
None
:
engine_prompt
[
"token_type_ids"
]
=
token_type_ids
post_process_tokens
(
model_config
,
engine_prompt
)
post_process_tokens
(
model_config
,
engine_prompt
)
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
engine_prompt
[
"multi_modal_data"
]
=
mm_data
engine_prompt
[
"multi_modal_data"
]
=
mm_data
return
full_prompt
,
engine_prompt
return
full_prompt
,
engine_prompt
def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """
    Return position of the first 1 or the length of the list
    if not found.

    The token type ids of a sentence pair are expected to be a run of
    zeros followed by a run of ones, so the whole list can be represented
    by the index at which the ones start.

    Args:
        token_type_ids: Per-token type ids; every entry must be 0 or 1.

    Returns:
        The index of the first 1, or ``len(token_type_ids)`` when the list
        contains no 1 (including the empty list).

    Raises:
        ValueError: If an entry is neither 0 nor 1 (negative values
            included), or if a 0 appears after the first 1.
    """
    err_msg = "Token type ids are expected to be a sequence"\
              " of zeros followed by a sequence of ones"

    first_one = len(token_type_ids)
    for i, type_id in enumerate(token_type_ids):
        if type_id == 1:
            # Only the first occurrence can lower the recorded position.
            first_one = min(first_one, i)
        elif type_id != 0 or first_one < i:
            # Either an id outside {0, 1} (previously negative ids slipped
            # through unnoticed) or a zero after the ones have started.
            raise ValueError(err_msg)

    return first_one
vllm/model_executor/models/bert.py
View file @
39052dbc
...
@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
...
@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
...
@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
input_shape
=
input_ids
.
size
()
# Input embeddings.
token_type_ids
=
_decode_token_type_ids
(
input_ids
)
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
inputs_embeds
.
device
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
...
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
...
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
embeddings
=
embedding_class
(
self
.
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.encoder"
)
prefix
=
f
"
{
prefix
}
.encoder"
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position
_id
s
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
else
:
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
position_ids
=
positions
)
token_type_ids
=
token_type_ids
)
return
self
.
encoder
(
hidden_states
)
return
self
.
encoder
(
hidden_states
)
def
_load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
_load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
...
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
return
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
positions
=
positions
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
...
@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
...
@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
})
})
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
# Here we encode the token type ids together with the input ids.
SupportsCrossEncoding
,
SupportsQuant
):
# Since we use int 32 for the input IDs and the vocabulary size
# is way lower than 2**31, there is room to encode additional
# bits. At the same time, for cross-encoder use cases, the
# token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in the 31st
# bit. We avoid the 32nd bit because that would produce a negative
# number, which could be used to signal other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForClassification module is outside of the BertModel
# which is compiled, we can do the encoding here and then separate
# the information again in the Embedding layer. Since with bit masks
# we can do this entirely with torch operations and without branching,
# it works with torch compile.
TOKEN_TYPE_SHIFT
=
30
def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """Pack ``token_type_ids`` into ``input_ids`` in place.

    Each token type id is shifted left by ``TOKEN_TYPE_SHIFT`` bits and
    OR-ed into the corresponding entry of ``input_ids``. Only the first
    ``token_type_ids.shape[0]`` entries of ``input_ids`` are modified.
    Nothing is returned; the caller's ``input_ids`` tensor is mutated.
    """
    # input_ids can be padded to the right
    input_ids[:token_type_ids.shape[0]].bitwise_or_(
        token_type_ids << TOKEN_TYPE_SHIFT)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Split packed ``input_ids`` back into token ids and token type ids.

    Returns the value stored at bit ``TOKEN_TYPE_SHIFT`` of every entry as
    a tensor of token type ids, and clears that bit in ``input_ids`` in
    place so that only the remaining bits are left in it.
    """

    # Mask with only bit TOKEN_TYPE_SHIFT set, same shape as input_ids.
    ids_mask = torch.ones(input_ids.shape,
                          dtype=torch.int32,
                          device=input_ids.device) << TOKEN_TYPE_SHIFT
    # Complement of the mask: every bit except TOKEN_TYPE_SHIFT.
    tokens_mask = ids_mask.bitwise_not()

    # Extract the token type bit before it is cleared below.
    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT

    # In-place: strip the token type bit from the caller's input_ids.
    input_ids.bitwise_and_(tokens_mask)

    return token_type_ids
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
token_type_ids
is
not
None
:
assert
self
.
bert
.
config
.
vocab_size
<
(
1
<<
TOKEN_TYPE_SHIFT
)
assert
input_ids
is
not
None
_encode_token_type_ids
(
input_ids
,
token_type_ids
)
return
self
.
bert
(
input_ids
=
input_ids
,
return
self
.
bert
(
input_ids
=
input_ids
,
position
_id
s
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
)
token_type_ids
=
token_type_ids
)
vllm/model_executor/models/roberta.py
View file @
39052dbc
...
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
...
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler
,
Pooler
)
DispatchPooler
,
Pooler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
from
vllm.model_executor.models.bert
import
(
TOKEN_TYPE_SHIFT
,
BertEmbeddingModel
,
BertModel
,
_decode_token_type_ids
,
_encode_token_type_ids
)
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
)
maybe_prefix
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
class
RobertaEmbedding
(
nn
.
Module
):
class
RobertaEmbedding
(
nn
.
Module
):
...
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
...
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
input_shape
=
input_ids
.
size
()
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
token_type_ids
=
_decode_token_type_ids
(
input_ids
)
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
inputs_embeds
.
device
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
...
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
position_ids
=
positions
,
position_ids
=
positions
,
padding_idx
=
self
.
padding_idx
)
padding_idx
=
self
.
padding_idx
)
return
self
.
model
(
input_ids
,
return
self
.
model
(
input_ids
=
input_ids
,
positions
,
positions
=
positions
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
...
@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return
loader
.
load_weights
(
weights_list
,
mapper
=
mapper
)
return
loader
.
load_weights
(
weights_list
,
mapper
=
mapper
)
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
SupportsV0Only
):
"""A model that uses Roberta to provide embedding functionalities.
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
This class encapsulates the BertModel and provides an interface for
...
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
replace_roberta_positions
(
input_ids
=
input_ids
,
replace_roberta_positions
(
input_ids
=
input_ids
,
position_ids
=
positions
,
position_ids
=
positions
,
padding_idx
=
self
.
padding_idx
)
padding_idx
=
self
.
padding_idx
)
if
token_type_ids
is
not
None
:
assert
self
.
roberta
.
config
.
vocab_size
<
(
1
<<
TOKEN_TYPE_SHIFT
)
assert
input_ids
is
not
None
_encode_token_type_ids
(
input_ids
,
token_type_ids
)
return
self
.
roberta
(
input_ids
=
input_ids
,
return
self
.
roberta
(
input_ids
=
input_ids
,
position
_id
s
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
)
token_type_ids
=
token_type_ids
)
# Adapted from transformers
# Adapted from transformers
...
...
vllm/pooling_params.py
View file @
39052dbc
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
import
msgspec
...
@@ -46,6 +46,9 @@ class PoolingParams(
...
@@ -46,6 +46,9 @@ class PoolingParams(
requires_token_ids
:
bool
=
False
requires_token_ids
:
bool
=
False
"""Internal use only."""
"""Internal use only."""
extra_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
@
property
@
property
...
@@ -167,7 +170,8 @@ class PoolingParams(
...
@@ -167,7 +170,8 @@ class PoolingParams(
f
"softmax=
{
self
.
softmax
}
, "
f
"softmax=
{
self
.
softmax
}
, "
f
"step_tag_id=
{
self
.
step_tag_id
}
, "
f
"step_tag_id=
{
self
.
step_tag_id
}
, "
f
"returned_token_ids=
{
self
.
returned_token_ids
}
, "
f
"returned_token_ids=
{
self
.
returned_token_ids
}
, "
f
"requires_token_ids=
{
self
.
requires_token_ids
}
)"
)
f
"requires_token_ids=
{
self
.
requires_token_ids
}
, "
f
"extra_kwargs=
{
self
.
extra_kwargs
}
)"
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
39052dbc
...
@@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
reorder_batch_threshold
:
Optional
[
int
]
=
None
self
.
reorder_batch_threshold
:
Optional
[
int
]
=
None
def _init_model_kwargs(self, num_tokens: int):
    """Build the extra keyword arguments passed to the model's forward.

    When at least one pooling request in the current batch carries
    ``extra_kwargs["compressed_token_type_ids"]`` (the position at which
    the token type switches from 0 to 1), a flat ``token_type_ids`` tensor
    covering all requests in the batch is reconstructed and returned under
    the ``"token_type_ids"`` key. Requests without that entry contribute
    all zeros. Otherwise an empty dict is returned.

    Note: ``num_tokens`` is not read by this method.
    """
    model_kwargs: dict[str, Any] = {}
    num_reqs = self.input_batch.num_reqs

    pooling_params = self.input_batch.pooling_metadata.pooling_params
    num_pooling_reqs = len(pooling_params)
    if not num_pooling_reqs:
        return model_kwargs
    assert num_pooling_reqs == num_reqs

    # request index -> position of the first token whose type id is 1
    first_one_by_req: dict[int, Any] = {}
    for req_idx, param in enumerate(pooling_params):
        extra = param.extra_kwargs
        if extra is None:
            continue
        first_one = extra.get("compressed_token_type_ids")
        if first_one is not None:
            first_one_by_req[req_idx] = first_one

    if not first_one_by_req:
        return model_kwargs

    seq_lens = self.seq_lens[:num_reqs]
    per_request_ids = []
    for req_idx in range(num_reqs):
        seq_len = seq_lens[req_idx]
        # Defaulting to the sequence length yields all zeros for requests
        # that did not provide token type ids.
        first_one = first_one_by_req.get(req_idx, seq_len)
        per_request_ids.append((torch.arange(seq_len) >= first_one).int())

    model_kwargs["token_type_ids"] = torch.concat(per_request_ids).to(
        device=self.device)
    return model_kwargs
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
model_mm_kwargs
=
self
.
_extract_mm_kwargs
(
scheduler_output
)
model_mm_kwargs
=
self
.
_extract_mm_kwargs
(
scheduler_output
)
model_kwargs
=
self
.
_init_model_kwargs
(
num_scheduled_tokens
)
else
:
else
:
# For text-only models, we use token ids as input.
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
model_kwargs
=
self
.
_init_model_kwargs
(
num_input_tokens
)
inputs_embeds
=
None
inputs_embeds
=
None
model_mm_kwargs
=
{}
model_mm_kwargs
=
{}
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs
,
model_mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
),
),
**
model_kwargs
,
)
)
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
...
@@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
num_scheduled_tokens
):
num_scheduled_tokens
):
model_kwargs
=
self
.
_init_model_kwargs
(
num_tokens
)
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
...
@@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs
,
model_mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
),
),
**
model_kwargs
,
)
)
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment