Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
39052dbc
Unverified
Commit
39052dbc
authored
Aug 11, 2025
by
Maximilien de Bayser
Committed by
GitHub
Aug 10, 2025
Browse files
Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by:
Max de Bayser
<
mbayser@br.ibm.com
>
parent
9c97a1c3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
235 additions
and
130 deletions
+235
-130
tests/entrypoints/openai/test_rerank.py
tests/entrypoints/openai/test_rerank.py
+3
-1
tests/entrypoints/openai/test_score.py
tests/entrypoints/openai/test_score.py
+3
-1
tests/models/language/pooling/test_scoring.py
tests/models/language/pooling/test_scoring.py
+9
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+26
-28
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+31
-51
vllm/entrypoints/score_utils.py
vllm/entrypoints/score_utils.py
+37
-3
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+63
-25
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+17
-19
vllm/pooling_params.py
vllm/pooling_params.py
+6
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+40
-0
No files found.
tests/entrypoints/openai/test_rerank.py
View file @
39052dbc
...
...
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
invocation_output
[
"results"
]):
assert
rerank_result
.
keys
()
==
invocations_result
.
keys
()
assert
rerank_result
[
"relevance_score"
]
==
pytest
.
approx
(
invocations_result
[
"relevance_score"
],
rel
=
0.01
)
invocations_result
[
"relevance_score"
],
rel
=
0.05
)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_score.py
View file @
39052dbc
...
...
@@ -220,7 +220,9 @@ class TestModel:
invocation_output
[
"data"
]):
assert
score_data
.
keys
()
==
invocation_data
.
keys
()
assert
score_data
[
"score"
]
==
pytest
.
approx
(
invocation_data
[
"score"
],
rel
=
0.01
)
invocation_data
[
"score"
],
rel
=
0.05
)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
def
test_activation
(
self
,
server
:
RemoteOpenAIServer
,
model
:
dict
[
str
,
Any
]):
...
...
tests/models/language/pooling/test_scoring.py
View file @
39052dbc
...
...
@@ -23,6 +23,15 @@ TEXTS_2 = [
"The capital of Germany is Berlin."
,
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse wrapper that runs each test in this module on both engines.

    The fixture body is intentionally empty: requesting
    ``run_with_both_engines`` is all that is needed. This can be promoted
    up to conftest.py to run for every test in a package.
    """
DTYPE
=
"half"
...
...
vllm/entrypoints/llm.py
View file @
39052dbc
...
...
@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template
,
parse_chat_messages
,
resolve_chat_template_content_format
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
ScoreMultiModalParam
,
_cosine_similarity
,
_validate_score_input_lens
,
compress_token_type_ids
,
get_score_prompt
)
# yapf: enable
from
vllm.entrypoints.utils
import
(
_validate_truncation_size
,
log_non_default_args
)
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
...
...
@@ -1329,6 +1333,7 @@ class LLM:
model_config
=
self
.
llm_engine
.
model_config
pooling_params
.
verify
(
"score"
,
model_config
)
pooling_params_list
=
list
[
PoolingParams
]()
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
...
...
@@ -1339,38 +1344,31 @@ class LLM:
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
if
model_config
.
is_multimodal_model
:
for
q
,
d
in
input_pairs
:
_
,
engine_prompt
=
get_score_prompt
(
model_config
=
model_config
,
data_1
=
q
,
data_2
=
d
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
)
model_config
=
self
.
llm_engine
.
model_config
parsed_prompts
.
append
(
engine_prompt
)
else
:
for
q
,
t
in
input_pairs
:
if
model_config
.
use_pad_token
:
# cross_encoder models defaults to using pad_token.
prompt_inputs
=
tokenizer
(
text
=
q
,
# type: ignore[arg-type]
text_pair
=
t
,
# type: ignore[arg-type]
**
tokenization_kwargs
)
else
:
# `llm as reranker` models defaults to not using pad_token.
prompt_inputs
=
tokenizer
(
text
=
q
+
t
,
# type: ignore[operator]
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
parsed_prompts
.
append
(
engine_prompt
)
for
q
,
d
in
input_pairs
:
_
,
engine_prompt
=
get_score_prompt
(
model_config
=
model_config
,
data_1
=
q
,
data_2
=
d
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
)
if
envs
.
VLLM_USE_V1
and
(
token_type_ids
:
=
engine_prompt
.
pop
(
"token_type_ids"
,
None
)):
params
=
pooling_params
.
clone
()
compressed
=
compress_token_type_ids
(
token_type_ids
)
params
.
extra_kwargs
=
{
"compressed_token_type_ids"
:
compressed
}
pooling_params_list
.
append
(
params
)
else
:
pooling_params_list
.
append
(
pooling_params
)
parsed_prompts
.
append
(
engine_prompt
)
self
.
_validate_and_add_requests
(
prompts
=
parsed_prompts
,
params
=
pooling_params
,
params
=
pooling_params
_list
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
)
...
...
vllm/entrypoints/openai/serving_score.py
View file @
39052dbc
...
...
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
from
fastapi
import
Request
from
vllm
import
envs
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
...
...
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
ScoreResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
ScoreMultiModalParam
,
_cosine_similarity
,
_validate_score_input_lens
,
compress_token_type_ids
,
get_score_prompt
)
# yapf: enable
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
...
...
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
)
self
.
_validate_input
(
request
,
engine_prompt
[
"prompt_token_ids"
],
full_prompt
)
if
request
.
mm_processor_kwargs
is
not
None
:
engine_prompt
[
"mm_processor_kwargs"
]
=
request
.
mm_processor_kwargs
...
...
@@ -188,64 +195,27 @@ class ServingScores(OpenAIServing):
input_pairs
=
[(
t1
,
t2
)
for
t1
,
t2
in
zip
(
data_1
,
data_2
)]
if
self
.
model_config
.
is_multimodal_model
:
preprocess_async
=
make_async
(
self
.
_preprocess_score
,
executor
=
self
.
_tokenizer_executor
)
preprocess_async
=
make_async
(
self
.
_preprocess_score
,
executor
=
self
.
_tokenizer_executor
)
preprocessed_prompts
=
await
asyncio
.
gather
(
*
(
preprocess_async
(
request
=
request
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
data_1
=
t1
,
data_2
=
t2
)
for
t1
,
t2
in
input_pairs
))
preprocessed_prompts
=
await
asyncio
.
gather
(
*
(
preprocess_async
(
request
=
request
,
tokenizer
=
tokenizer
,
tokenization_kwargs
=
tokenization_kwargs
,
data_1
=
t1
,
data_2
=
t2
)
for
t1
,
t2
in
input_pairs
))
for
full_prompt
,
engine_prompt
in
preprocessed_prompts
:
request_prompts
.
append
(
full_prompt
)
engine_prompts
.
append
(
engine_prompt
)
else
:
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
use_pad_token
=
self
.
model_config
.
use_pad_token
if
use_pad_token
:
# cross_encoder models defaults to using pad_token.
tokenized_prompts
=
await
asyncio
.
gather
(
*
(
tokenize_async
(
text
=
t1
,
# type: ignore[arg-type]
text_pair
=
t2
,
# type: ignore[arg-type]
**
tokenization_kwargs
)
for
t1
,
t2
in
input_pairs
))
else
:
# `llm as reranker` models defaults to not using pad_token.
tokenized_prompts
=
await
asyncio
.
gather
(
*
(
tokenize_async
(
text
=
t1
+
# type: ignore[operator]
t2
,
**
tokenization_kwargs
)
for
t1
,
t2
in
input_pairs
))
for
prompt_inputs
,
(
t1
,
t2
)
in
zip
(
tokenized_prompts
,
input_pairs
):
sep_token
=
tokenizer
.
sep_token
if
(
tokenizer
.
sep_token
and
use_pad_token
)
else
''
request_prompt
=
f
"
{
t1
}{
sep_token
}{
t2
}
"
input_ids
=
prompt_inputs
[
"input_ids"
]
text_token_prompt
=
\
self
.
_validate_input
(
request
,
input_ids
,
request_prompt
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
text_token_prompt
[
"prompt_token_ids"
],
token_type_ids
=
prompt_inputs
.
get
(
"token_type_ids"
))
request_prompts
.
append
(
request_prompt
)
engine_prompts
.
append
(
engine_prompt
)
for
full_prompt
,
engine_prompt
in
preprocessed_prompts
:
request_prompts
.
append
(
full_prompt
)
engine_prompts
.
append
(
engine_prompt
)
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
pooling_params
=
request
.
to_pooling_params
()
default_
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
"score"
,
self
.
model_config
)
default_
pooling_params
.
verify
(
"score"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
self
.
_log_inputs
(
request_id_item
,
request_prompts
[
i
],
params
=
pooling_params
,
params
=
default_
pooling_params
,
lora_request
=
lora_request
)
if
envs
.
VLLM_USE_V1
and
(
token_type_ids
:
=
engine_prompt
.
pop
(
"token_type_ids"
,
None
)):
pooling_params
=
default_pooling_params
.
clone
()
compressed
=
compress_token_type_ids
(
token_type_ids
)
pooling_params
.
extra_kwargs
=
{
"compressed_token_type_ids"
:
compressed
}
else
:
pooling_params
=
(
default_pooling_params
)
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
...
...
vllm/entrypoints/score_utils.py
View file @
39052dbc
...
...
@@ -184,15 +184,49 @@ def get_score_prompt(
model_config
,
tokenizer
,
)
from
vllm.model_executor.model_loader
import
get_model_cls
full_prompt
=
apply_score_template
(
model_config
,
prompt_1
,
prompt_2
)
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
model
=
get_model_cls
(
model_config
)
if
supports_score_template
(
model
):
full_prompt
=
apply_score_template
(
model_config
,
prompt_1
,
prompt_2
)
prompt_inputs
=
tokenizer
(
full_prompt
,
**
tokenization_kwargs
)
elif
model_config
.
use_pad_token
:
# cross_encoder models defaults to using pad_token.
prompt_inputs
=
tokenizer
(
text
=
prompt_1
,
text_pair
=
prompt_2
,
**
tokenization_kwargs
)
full_prompt
=
tokenizer
.
decode
(
prompt_inputs
[
"input_ids"
])
else
:
# `llm as reranker` models defaults to not using pad_token.
full_prompt
=
prompt_1
+
prompt_2
prompt_inputs
=
tokenizer
(
text
=
full_prompt
,
**
tokenization_kwargs
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"input_ids"
])
if
(
token_type_ids
:
=
prompt_inputs
.
get
(
"token_type_ids"
))
is
not
None
:
engine_prompt
[
"token_type_ids"
]
=
token_type_ids
post_process_tokens
(
model_config
,
engine_prompt
)
if
mm_data
is
not
None
:
engine_prompt
[
"multi_modal_data"
]
=
mm_data
return
full_prompt
,
engine_prompt
def compress_token_type_ids(token_type_ids: list[int]) -> int:
    """Compress a 0/1 token-type sequence into the index of its first 1.

    Token type ids are expected to be a (possibly empty) run of zeros
    followed by a (possibly empty) run of ones, so the whole sequence can
    be represented by the position at which the ones start.

    Args:
        token_type_ids: Per-token type ids; every entry must be 0 or 1.

    Returns:
        The position of the first 1, or ``len(token_type_ids)`` if the
        sequence contains no 1.

    Raises:
        ValueError: If an entry is neither 0 nor 1, or if a 0 appears
            after a 1.
    """
    err_msg = "Token type ids are expected to be a sequence"\
        " of zeros followed by a sequence of ones"

    first_one = len(token_type_ids)
    for i, type_id in enumerate(token_type_ids):
        if type_id == 1:
            # Only the first 1 moves the boundary; later ones are no-ops.
            if first_one > i:
                first_one = i
        elif type_id == 0:
            # A zero after the boundary means the sequence is not of the
            # form 0...01...1 and cannot be represented by one index.
            if first_one < i:
                raise ValueError(err_msg)
        else:
            # Reject every value outside {0, 1}, including negative ids,
            # instead of silently treating them as valid.
            raise ValueError(err_msg)

    return first_one
vllm/model_executor/models/bert.py
View file @
39052dbc
...
...
@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
...
@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
input_shape
=
input_ids
.
size
()
# Input embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
token_type_ids
=
_decode_token_type_ids
(
input_ids
)
# Position embeddings.
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
inputs_embeds
.
device
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
...
...
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
)
->
None
:
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
config
)
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
embeddings
=
embedding_class
(
self
.
config
)
self
.
encoder
=
BertEncoder
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.encoder"
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position
_id
s
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
position_ids
=
positions
)
return
self
.
encoder
(
hidden_states
)
def
_load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
=
input_ids
,
position_ids
=
positions
,
token_type_ids
=
token_type_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
...
...
@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
})
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
SupportsCrossEncoding
,
SupportsQuant
):
# Here we encode the token type ids together with the input ids.
# Since we use int 32 for the input IDs and the vocabulary size
# is way lower than 2**31, there is room to encode additional
# bits. At the same time, for cross-encoder use cases, the
# token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in the 31st
# bit. We avoid the 32nd bit because that would produce a negative
# number, which could be used to signal other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForClassification module is outside of the BertModel
# which is compiled, we can do the encoding here and then separate
# the information again in the Embedding layer. Since with bit masks
# we can do this entirely with torch operations and without branching,
# it works with torch compile.
TOKEN_TYPE_SHIFT
=
30
def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """Pack ``token_type_ids`` into the high bit of ``input_ids`` in place.

    Each token type id is shifted left by ``TOKEN_TYPE_SHIFT`` and OR-ed
    into the corresponding entry of ``input_ids``. Nothing is returned;
    ``input_ids`` is mutated.
    """
    # input_ids can be padded to the right, so only its leading
    # token_type_ids.shape[0] entries receive a type bit.
    num_typed = token_type_ids.shape[0]
    shifted_types = token_type_ids << TOKEN_TYPE_SHIFT
    input_ids[:num_typed].bitwise_or_(shifted_types)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Split the packed token-type bit back out of ``input_ids``.

    Returns a tensor holding the bit stored at position
    ``TOKEN_TYPE_SHIFT`` of every entry, and clears that bit from
    ``input_ids`` in place so only the token ids remain.
    """
    # Mask with just the token-type bit set, same shape/device as the input.
    type_bit = torch.ones(input_ids.shape,
                          dtype=torch.int32,
                          device=input_ids.device) << TOKEN_TYPE_SHIFT

    # Read the type bit out first; the in-place clear below destroys it.
    extracted = input_ids.bitwise_and(type_bit) >> TOKEN_TYPE_SHIFT
    input_ids.bitwise_and_(type_bit.bitwise_not())
    return extracted
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
...
...
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
token_type_ids
is
not
None
:
assert
self
.
bert
.
config
.
vocab_size
<
(
1
<<
TOKEN_TYPE_SHIFT
)
assert
input_ids
is
not
None
_encode_token_type_ids
(
input_ids
,
token_type_ids
)
return
self
.
bert
(
input_ids
=
input_ids
,
position
_id
s
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
token_type_ids
=
token_type_ids
)
intermediate_tensors
=
intermediate_tensors
)
vllm/model_executor/models/roberta.py
View file @
39052dbc
...
...
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler
,
Pooler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
from
vllm.model_executor.models.bert
import
(
TOKEN_TYPE_SHIFT
,
BertEmbeddingModel
,
BertModel
,
_decode_token_type_ids
,
_encode_token_type_ids
)
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
)
from
vllm.sequence
import
IntermediateTensors
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
class
RobertaEmbedding
(
nn
.
Module
):
...
...
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
input_shape
=
input_ids
.
size
()
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
# Position embeddings.
token_type_ids
=
_decode_token_type_ids
(
input_ids
)
inputs_embeds
=
self
.
word_embeddings
(
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
inputs_embeds
.
device
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
embeddings
=
inputs_embeds
+
token_type_embeddings
+
position_embeddings
...
...
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
token_type_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
position_ids
=
positions
,
padding_idx
=
self
.
padding_idx
)
return
self
.
model
(
input_ids
,
positions
,
token_type_ids
=
token_type_ids
,
return
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
...
...
@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return
loader
.
load_weights
(
weights_list
,
mapper
=
mapper
)
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsV0Only
):
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
...
...
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
replace_roberta_positions
(
input_ids
=
input_ids
,
position_ids
=
positions
,
padding_idx
=
self
.
padding_idx
)
if
token_type_ids
is
not
None
:
assert
self
.
roberta
.
config
.
vocab_size
<
(
1
<<
TOKEN_TYPE_SHIFT
)
assert
input_ids
is
not
None
_encode_token_type_ids
(
input_ids
,
token_type_ids
)
return
self
.
roberta
(
input_ids
=
input_ids
,
position
_id
s
=
positions
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
token_type_ids
=
token_type_ids
)
intermediate_tensors
=
intermediate_tensors
)
# Adapted from transformers
...
...
vllm/pooling_params.py
View file @
39052dbc
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
copy
import
deepcopy
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
...
...
@@ -46,6 +46,9 @@ class PoolingParams(
requires_token_ids
:
bool
=
False
"""Internal use only."""
extra_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
@
property
...
...
@@ -167,7 +170,8 @@ class PoolingParams(
f
"softmax=
{
self
.
softmax
}
, "
f
"step_tag_id=
{
self
.
step_tag_id
}
, "
f
"returned_token_ids=
{
self
.
returned_token_ids
}
, "
f
"requires_token_ids=
{
self
.
requires_token_ids
}
)"
)
f
"requires_token_ids=
{
self
.
requires_token_ids
}
, "
f
"extra_kwargs=
{
self
.
extra_kwargs
}
)"
)
def
__post_init__
(
self
)
->
None
:
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
39052dbc
...
...
@@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
reorder_batch_threshold
:
Optional
[
int
]
=
None
def
_init_model_kwargs
(
self
,
num_tokens
:
int
):
model_kwargs
=
dict
[
str
,
Any
]()
num_reqs
=
self
.
input_batch
.
num_reqs
pooling_params
=
self
.
input_batch
.
pooling_metadata
.
pooling_params
num_pooling_reqs
=
len
(
pooling_params
)
if
num_pooling_reqs
==
0
:
return
model_kwargs
assert
num_pooling_reqs
==
num_reqs
token_type_id_requests
=
dict
[
int
,
Any
]()
for
i
,
param
in
enumerate
(
pooling_params
):
if
param
.
extra_kwargs
is
not
None
and
\
(
token_types
:
=
param
.
extra_kwargs
.
get
(
"compressed_token_type_ids"
))
is
not
None
:
token_type_id_requests
[
i
]
=
token_types
if
len
(
token_type_id_requests
)
==
0
:
return
model_kwargs
seq_lens
=
self
.
seq_lens
[:
num_reqs
]
token_type_ids
=
[]
for
i
in
range
(
num_reqs
):
pos
=
token_type_id_requests
.
get
(
i
,
seq_lens
[
i
])
ids
=
(
torch
.
arange
(
seq_lens
[
i
])
>=
pos
).
int
()
token_type_ids
.
append
(
ids
)
model_kwargs
[
"token_type_ids"
]
=
torch
.
concat
(
token_type_ids
).
to
(
device
=
self
.
device
)
return
model_kwargs
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
...
...
@@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
model_mm_kwargs
=
self
.
_extract_mm_kwargs
(
scheduler_output
)
model_kwargs
=
self
.
_init_model_kwargs
(
num_scheduled_tokens
)
else
:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
model_kwargs
=
self
.
_init_model_kwargs
(
num_input_tokens
)
inputs_embeds
=
None
model_mm_kwargs
=
{}
if
self
.
uses_mrope
:
...
...
@@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs
,
device
=
self
.
device
,
),
**
model_kwargs
,
)
if
self
.
use_aux_hidden_state_outputs
:
...
...
@@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with
self
.
maybe_dummy_run_with_lora
(
self
.
lora_config
,
num_scheduled_tokens
):
model_kwargs
=
self
.
_init_model_kwargs
(
num_tokens
)
if
self
.
supports_mm_inputs
:
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
...
...
@@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs
,
device
=
self
.
device
,
),
**
model_kwargs
,
)
if
self
.
use_aux_hidden_state_outputs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment