Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45badd05
Unverified
Commit
45badd05
authored
Jul 18, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 18, 2025
Browse files
[Core] Set pooling params based on task and model (#21128)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
4adc66f6
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
453 additions
and
227 deletions
+453
-227
tests/models/language/pooling/test_gritlm.py
tests/models/language/pooling/test_gritlm.py
+11
-15
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+33
-16
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+4
-4
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+32
-0
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+15
-3
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+13
-5
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+5
-0
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+21
-9
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+7
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+95
-54
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+8
-4
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+116
-69
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+0
-7
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+8
-4
vllm/pooling_params.py
vllm/pooling_params.py
+25
-16
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+6
-0
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+0
-4
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+11
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+39
-9
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-0
No files found.
tests/models/language/pooling/test_gritlm.py
View file @
45badd05
...
...
@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
importlib.util
from
array
import
array
import
numpy
as
np
import
openai
import
pytest
from
scipy.spatial.distance
import
cosine
...
...
@@ -14,10 +12,6 @@ from vllm.config import ModelConfig
from
....utils
import
RemoteOpenAIServer
# GritLM embedding implementation is only supported by XFormers backend.
pytestmark
=
pytest
.
mark
.
skipif
(
not
importlib
.
util
.
find_spec
(
"xformers"
),
reason
=
"GritLM requires XFormers"
)
MODEL_NAME
=
"parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN
=
4000
...
...
@@ -26,11 +20,11 @@ def _arr(arr):
"""
Convert a list of integers to an array of integers.
"""
return
array
(
"i"
,
arr
)
return
np
.
array
(
arr
)
def
test_find_array
():
from
vllm.model_executor.models.gritlm
import
GritLMPool
er
from
vllm.model_executor.models.gritlm
import
GritLM
Mean
Pool
model_config
=
ModelConfig
(
MODEL_NAME
,
...
...
@@ -41,17 +35,19 @@ def test_find_array():
dtype
=
"bfloat16"
,
seed
=
0
,
)
pool
er
=
GritLMPool
er
(
model_config
=
model_config
)
pool
ing
=
GritLM
Mean
Pool
(
model_config
=
model_config
)
arr
=
_arr
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
])
assert
pooler
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
0
)
==
3
assert
pooler
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
1
)
==
3
assert
pooler
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
5
)
==
-
1
assert
pooler
.
_find_array
(
arr
,
_arr
([
3
,
5
]),
start_idx
=
0
)
==
-
1
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
0
)
==
3
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
1
)
==
3
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=
5
)
==
-
1
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
end_idx
=
3
)
==
-
1
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
end_idx
=
4
)
==
3
assert
pooling
.
_find_array
(
arr
,
_arr
([
3
,
5
]),
start_idx
=
0
)
==
-
1
with
pytest
.
raises
(
ValueError
):
pool
er
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=-
1
)
pool
ing
.
_find_array
(
arr
,
_arr
([
3
,
4
,
5
]),
start_idx
=-
1
)
def
run_llm_encode
(
...
...
vllm/entrypoints/llm.py
View file @
45badd05
...
...
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from
vllm.outputs
import
(
ClassificationRequestOutput
,
EmbeddingRequestOutput
,
PoolingRequestOutput
,
RequestOutput
,
ScoringRequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
,
PoolingTask
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
...
...
@@ -964,6 +964,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -979,6 +980,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -994,6 +996,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -1010,6 +1013,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -1026,6 +1030,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -1040,6 +1045,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
...
...
...
@@ -1059,6 +1065,7 @@ class LLM:
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
Union
[
list
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
pooling_task
:
PoolingTask
=
"encode"
,
)
->
list
[
PoolingRequestOutput
]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
...
...
@@ -1080,6 +1087,7 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use.
Returns:
A list of `PoolingRequestOutput` objects containing the
...
...
@@ -1116,11 +1124,12 @@ class LLM:
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
elif
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
model_config
)
if
isinstance
(
pooling_params
,
PoolingParams
):
pooling_params
.
verify
(
pooling_task
,
model_config
)
else
:
for
pooling_param
in
pooling_params
:
pooling_param
.
verify
(
model_config
)
pooling_param
.
verify
(
pooling_task
,
model_config
)
tokenization_kwargs
=
dict
[
str
,
Any
]()
_validate_truncation_size
(
model_config
.
max_model_len
,
...
...
@@ -1181,12 +1190,15 @@ class LLM:
raise
ValueError
(
"Embedding API is not supported by this model. "
"Please set `--task embed`."
)
items
=
self
.
encode
(
prompts
,
items
=
self
.
encode
(
prompts
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
use_tqdm
=
use_tqdm
,
pooling_params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
pooling_task
=
"embed"
,
)
return
[
EmbeddingRequestOutput
.
from_base
(
item
)
for
item
in
items
]
...
...
@@ -1228,10 +1240,13 @@ class LLM:
"Classification API is not supported by this model. "
"Please set `--task classify`."
)
items
=
self
.
encode
(
prompts
,
items
=
self
.
encode
(
prompts
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
pooling_task
=
"classify"
,
)
return
[
ClassificationRequestOutput
.
from_base
(
item
)
for
item
in
items
]
...
...
@@ -1251,7 +1266,9 @@ class LLM:
truncate_prompt_tokens
=
truncate_prompt_tokens
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
pooling_task
=
"embed"
,
)
encoded_output_1
:
list
[
PoolingRequestOutput
]
=
encoded_output
[
0
:
len
(
text_1
)]
...
...
@@ -1287,7 +1304,7 @@ class LLM:
if
len
(
data_1
)
==
1
:
data_1
=
data_1
*
len
(
data_2
)
pooling_params
=
PoolingParams
(
use_cross_encoder
=
True
)
pooling_params
=
PoolingParams
(
task
=
"score"
)
tokenization_kwargs
:
dict
[
str
,
Any
]
=
{}
_validate_truncation_size
(
self
.
llm_engine
.
model_config
.
max_model_len
,
truncate_prompt_tokens
,
tokenization_kwargs
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
45badd05
...
...
@@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
def
to_pooling_params
(
self
):
return
PoolingParams
()
class
RerankRequest
(
OpenAIBaseModel
):
...
...
@@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
def
to_pooling_params
(
self
):
return
PoolingParams
()
class
RerankDocument
(
BaseModel
):
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
45badd05
...
...
@@ -6,6 +6,7 @@ from typing import Optional, Union, cast
import
numpy
as
np
from
fastapi
import
Request
from
typing_extensions
import
override
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
...
...
@@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.logger
import
init_logger
from
vllm.outputs
import
ClassificationOutput
,
PoolingRequestOutput
from
vllm.pooling_params
import
PoolingParams
logger
=
init_logger
(
__name__
)
class
ClassificationMixin
(
OpenAIServing
):
@
override
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
...
...
@@ -75,6 +78,7 @@ class ClassificationMixin(OpenAIServing):
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
@
override
def
_build_response
(
self
,
ctx
:
ServeContext
,
...
...
@@ -158,3 +162,31 @@ class ServingClassification(ClassificationMixin):
)
return
await
super
().
handle
(
ctx
)
# type: ignore
@
override
def
_validate_request
(
self
,
ctx
:
ClassificationServeContext
,
)
->
Optional
[
ErrorResponse
]:
if
error
:
=
super
().
_validate_request
(
ctx
):
return
error
ctx
.
truncate_prompt_tokens
=
ctx
.
request
.
truncate_prompt_tokens
return
None
@
override
def
_create_pooling_params
(
self
,
ctx
:
ClassificationServeContext
,
)
->
Union
[
PoolingParams
,
ErrorResponse
]:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
try
:
pooling_params
.
verify
(
"classify"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
vllm/entrypoints/openai/serving_embedding.py
View file @
45badd05
...
...
@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from
vllm.logger
import
init_logger
from
vllm.outputs
import
(
EmbeddingOutput
,
EmbeddingRequestOutput
,
PoolingRequestOutput
)
from
vllm.pooling_params
import
PoolingParams
logger
=
init_logger
(
__name__
)
...
...
@@ -45,6 +46,7 @@ def _get_embedding(
class
EmbeddingMixin
(
OpenAIServing
):
@
override
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
...
...
@@ -97,6 +99,7 @@ class EmbeddingMixin(OpenAIServing):
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
@
override
def
_build_response
(
self
,
ctx
:
ServeContext
,
...
...
@@ -191,11 +194,20 @@ class OpenAIServingEmbedding(EmbeddingMixin):
ctx
.
truncate_prompt_tokens
=
ctx
.
request
.
truncate_prompt_tokens
pooling_params
=
ctx
.
request
.
to_pooling_params
()
return
None
@
override
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
[
EmbeddingRequest
],
)
->
Union
[
PoolingParams
,
ErrorResponse
]:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
try
:
pooling_params
.
verify
(
self
.
model_config
)
pooling_params
.
verify
(
"embed"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
None
return
pooling_params
vllm/entrypoints/openai/serving_engine.py
View file @
45badd05
...
...
@@ -305,6 +305,16 @@ class OpenAIServing:
" Please, select a smaller truncation size."
)
return
None
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
,
)
->
Union
[
PoolingParams
,
ErrorResponse
]:
if
not
hasattr
(
ctx
.
request
,
"to_pooling_params"
):
return
self
.
create_error_response
(
"Request type does not support pooling parameters"
)
return
ctx
.
request
.
to_pooling_params
()
async
def
_prepare_generators
(
self
,
ctx
:
ServeContext
,
...
...
@@ -318,11 +328,9 @@ class OpenAIServing:
trace_headers
=
(
None
if
ctx
.
raw_request
is
None
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
))
if
not
hasattr
(
ctx
.
request
,
"to_pooling_params"
):
return
self
.
create_error_response
(
"Request type does not support pooling parameters"
)
pooling_params
=
ctx
.
request
.
to_pooling_params
()
pooling_params
=
self
.
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
45badd05
...
...
@@ -142,6 +142,11 @@ class OpenAIServingPooling(OpenAIServing):
try
:
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
"encode"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
vllm/entrypoints/openai/serving_score.py
View file @
45badd05
...
...
@@ -55,14 +55,13 @@ class ServingScores(OpenAIServing):
texts_1
:
list
[
str
],
texts_2
:
list
[
str
],
request
:
Union
[
RerankRequest
,
ScoreRequest
],
request_id
=
str
,
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
list
[
PoolingRequestOutput
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
input_texts
=
texts_1
+
texts_2
engine_prompts
:
list
[
TokensPrompt
]
=
[]
...
...
@@ -89,6 +88,11 @@ class ServingScores(OpenAIServing):
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
"embed"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
@@ -169,14 +173,13 @@ class ServingScores(OpenAIServing):
data_1
:
Union
[
list
[
str
],
list
[
ScoreContentPartParam
]],
data_2
:
Union
[
list
[
str
],
list
[
ScoreContentPartParam
]],
request
:
Union
[
RerankRequest
,
ScoreRequest
],
request_id
=
str
,
request_id
:
str
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
Union
[
LoRARequest
,
None
]]
=
None
,
prompt_adapter_request
:
Optional
[
Union
[
PromptAdapterRequest
,
None
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
list
[
PoolingRequestOutput
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
request_prompts
:
list
[
str
]
=
[]
engine_prompts
:
list
[
TokensPrompt
]
=
[]
...
...
@@ -245,7 +248,12 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
pooling_params
=
request
.
to_pooling_params
(
use_cross_encoder
=
True
)
pooling_params
=
request
.
to_pooling_params
()
try
:
pooling_params
.
verify
(
"score"
,
self
.
model_config
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
@@ -286,8 +294,7 @@ class ServingScores(OpenAIServing):
request_id
:
str
,
raw_request
:
Optional
[
Request
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
)
->
list
[
PoolingRequestOutput
]:
)
->
Union
[
list
[
PoolingRequestOutput
],
ErrorResponse
]:
(
lora_request
,
prompt_adapter_request
,
...
...
@@ -374,6 +381,8 @@ class ServingScores(OpenAIServing):
raw_request
,
request
.
truncate_prompt_tokens
,
)
if
isinstance
(
final_res_batch
,
ErrorResponse
):
return
final_res_batch
return
self
.
request_output_to_score_response
(
final_res_batch
,
...
...
@@ -420,6 +429,9 @@ class ServingScores(OpenAIServing):
raw_request
,
request
.
truncate_prompt_tokens
,
)
if
isinstance
(
final_res_batch
,
ErrorResponse
):
return
final_res_batch
return
self
.
request_output_to_rerank_response
(
final_res_batch
,
request_id
,
...
...
vllm/executor/executor_base.py
View file @
45badd05
...
...
@@ -4,6 +4,7 @@
import
asyncio
import
time
from
abc
import
ABC
,
abstractmethod
from
functools
import
cached_property
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
...
...
@@ -15,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.pooling_params
import
PoolingTask
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
make_async
...
...
@@ -135,6 +137,11 @@ class ExecutorBase(ABC):
return
self
.
collective_rpc
(
rpc_func
)
@
cached_property
# Avoid unnecessary RPC calls
def
supported_pooling_tasks
(
self
)
->
tuple
[
PoolingTask
,
...]:
output
=
self
.
collective_rpc
(
"get_supported_pooling_tasks"
)
return
tuple
({
task
for
tasks
in
output
for
task
in
tasks
})
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Optional
[
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]]:
...
...
vllm/model_executor/layers/pooler.py
View file @
45badd05
...
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
typing
import
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -15,13 +15,12 @@ from vllm.config import ModelConfig, PoolerConfig
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
PoolingMetadata
as
V0PoolingMetadata
)
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
,
PoolingTask
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingTask
=
Literal
[
"encode"
,
"embed"
,
"classify"
,
"score"
]
class
PoolingType
(
IntEnum
):
...
...
@@ -67,6 +66,15 @@ class ResolvedPoolingConfig:
)
@
dataclass
(
frozen
=
True
)
class
PoolingParamsUpdate
:
requires_token_ids
:
bool
=
False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
def
apply
(
self
,
params
:
PoolingParams
)
->
None
:
params
.
requires_token_ids
=
self
.
requires_token_ids
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
...
...
@@ -93,7 +101,10 @@ class Pooler(nn.Module, ABC):
return
SimplePooler
.
from_config
(
resolved_config
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
...
...
@@ -121,6 +132,23 @@ def get_prompt_lens(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
def
get_prompt_token_ids
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
assert
pooling_metadata
.
prompt_token_ids
is
not
None
,
(
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return
[
pooling_metadata
.
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
return
[
torch
.
tensor
(
seq_data_i
.
prompt_token_ids
)
for
seq_data_i
in
pooling_metadata
.
seq_data
.
values
()
]
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
return
PoolerClassify
()
...
...
@@ -165,7 +193,10 @@ class PoolingMethod(nn.Module, ABC):
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
@
abstractmethod
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
raise
NotImplementedError
@
abstractmethod
...
...
@@ -206,11 +237,14 @@ class PoolingMethod(nn.Module, ABC):
class
CLSPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
return
PoolingParams
Update
()
assert_never
(
task
)
...
...
@@ -236,11 +270,14 @@ class CLSPool(PoolingMethod):
class
LastPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
return
PoolingParams
Update
()
assert_never
(
task
)
...
...
@@ -262,9 +299,12 @@ class LastPool(PoolingMethod):
class
AllPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
if
task
==
"encode"
:
return
PoolingParams
()
return
PoolingParams
Update
()
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
...
...
@@ -299,11 +339,14 @@ class AllPool(PoolingMethod):
class
MeanPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
return
PoolingParams
Update
()
assert_never
(
task
)
...
...
@@ -520,8 +563,11 @@ class SimplePooler(Pooler):
self
.
pooling
=
pooling
self
.
head
=
head
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
forward
(
self
,
...
...
@@ -559,27 +605,13 @@ class StepPooler(Pooler):
self
.
step_tag_id
=
step_tag_id
self
.
returned_token_ids
=
returned_token_ids
def
get_prompt_token_ids
(
self
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
[
pooling_metadata
.
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
return
[
torch
.
tensor
(
seq_data_i
.
prompt_token_ids
)
for
seq_data_i
in
pooling_metadata
.
seq_data
.
values
()
]
def
extract_states
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
self
.
get_prompt_token_ids
(
pooling_metadata
)
prompt_token_ids
=
get_prompt_token_ids
(
pooling_metadata
)
pooled_data
=
list
[
torch
.
Tensor
]()
returned_token_ids
=
self
.
returned_token_ids
...
...
@@ -595,9 +627,12 @@ class StepPooler(Pooler):
return
pooled_data
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
if
task
==
"encode"
:
return
PoolingParams
(
logits_processing_need
s_token_ids
=
True
)
return
PoolingParams
Update
(
require
s_token_ids
=
True
)
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
...
...
@@ -650,19 +685,24 @@ class ClassifierPooler(nn.Module):
self
.
cross_encoder_act_fn
=
get_cross_encoder_activation_function
(
config
.
hf_config
)
if
act_fn
is
None
else
act_fn
def
_get_act_fn
(
self
,
use_cross_encoder
:
bool
):
return
(
self
.
cross_encoder_act_fn
if
use_cross_encoder
else
self
.
classification_act_fn
)
def
_get_act_fn
(
self
,
task
:
PoolingTask
):
if
task
==
"encode"
or
task
==
"classify"
:
return
self
.
classification_act_fn
if
task
==
"score"
:
return
self
.
cross_encoder_act_fn
raise
ValueError
(
f
"Unsupported task:
{
task
!
r
}
"
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
# The equalities are split up to keep mypy happy
if
task
==
"encode"
or
task
==
"classify"
or
task
==
"score"
:
return
PoolingParamsUpdate
()
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
()
if
task
==
"embed"
:
return
None
if
task
==
"classify"
:
return
PoolingParams
()
if
task
==
"score"
:
return
PoolingParams
(
use_cross_encoder
=
True
)
assert_never
(
task
)
...
...
@@ -682,27 +722,28 @@ class ClassifierPooler(nn.Module):
else
:
pooled_output
=
[
self
.
classifier
(
data
)
for
data
in
pooled_data
]
task_list
:
list
[
PoolingTask
]
if
isinstance
(
pooling_metadata
,
V0PoolingMetadata
):
use_cross_encoder
_list
=
[
pooling_param
.
use_cross_encoder
for
_
,
pooling_param
in
pooling_metadata
.
seq_groups
task
_list
=
[
task
for
_
,
pooling_param
in
pooling_metadata
.
seq_groups
if
(
task
:
=
pooling_param
.
task
)
is
not
None
]
else
:
use_cross_encoder
_list
=
[
pooling_param
.
use_cross_encoder
for
pooling_param
in
pooling_metadata
.
pooling_params
task
_list
=
[
task
for
pooling_param
in
pooling_metadata
.
pooling_params
if
(
task
:
=
pooling_param
.
task
)
is
not
None
]
assert
len
(
task_list
)
==
len
(
pooled_output
)
# shape of scores: (batch_size, num_labels)
if
all
(
use_cross_encoder
==
use_cross_encoder_list
[
0
]
for
use_cross_encoder
in
use_cross_encoder_list
):
act_fn
=
self
.
_get_act_fn
(
use_cross_encoder_list
[
0
])
if
len
(
set
(
task_list
))
<=
1
:
act_fn
=
self
.
_get_act_fn
(
task_list
[
0
])
scores
=
act_fn
(
pooled_output
)
else
:
scores
=
torch
.
stack
([
self
.
_get_act_fn
(
use_cross_encoder
)(
vecs
)
for
use_cross_encoder
,
vecs
in
zip
(
use_cross_encoder_list
,
pooled_output
)
self
.
_get_act_fn
(
task
)(
vecs
)
for
task
,
vecs
in
zip
(
task_list
,
pooled_output
)
])
return
build_output
(
scores
)
vllm/model_executor/models/bert.py
View file @
45badd05
...
...
@@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingTask
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.pooling_params
import
Pooling
Params
from
vllm.pooling_params
import
Pooling
Task
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
...
...
@@ -91,8 +92,11 @@ class BertPooler(Pooler):
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
forward
(
self
,
...
...
vllm/model_executor/models/gritlm.py
View file @
45badd05
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
array
import
array
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
PoolerHead
,
PoolerNormalize
from
vllm.model_executor.layers.pooler
import
(
Pooler
,
PoolerHead
,
PoolerNormalize
,
PoolingParamsUpdate
,
build_output
,
get_prompt_lens
,
get_prompt_token_ids
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingT
ensors
)
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.pooling_params
import
PoolingT
ask
from
vllm.sequence
import
PoolerOutput
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
SupportsV0Only
...
...
@@ -20,7 +26,8 @@ from .interfaces import SupportsV0Only
logger
=
init_logger
(
__name__
)
class
GritLMPooler
(
nn
.
Module
):
class
GritLMMeanPool
(
nn
.
Module
):
"""As `MeanPool`, but only includes non-instruction tokens."""
def
__init__
(
self
,
model_config
:
ModelConfig
):
super
().
__init__
()
...
...
@@ -38,8 +45,8 @@ class GritLMPooler(nn.Module):
for
tok
in
[
"<s>"
,
"▁<"
,
"<"
,
"|"
,
"embed"
,
">"
,
"<0x0A>"
,
"user"
]
}
def
tokens_to_ids
(
tokens
:
list
[
str
])
->
array
:
return
array
(
"i"
,
[
self
.
token_ids
[
token
]
for
token
in
tokens
])
def
tokens_to_ids
(
tokens
:
list
[
str
])
->
np
.
nd
array
:
return
np
.
array
([
self
.
token_ids
[
token
]
for
token
in
tokens
])
self
.
user_pattern_ids
=
tokens_to_ids
(
[
"▁<"
,
"|"
,
"user"
,
"|"
,
">"
,
"<0x0A>"
])
...
...
@@ -48,32 +55,44 @@ class GritLMPooler(nn.Module):
self
.
embed_pattern_ids
=
tokens_to_ids
(
[
"▁<"
,
"|"
,
"embed"
,
"|"
,
">"
,
"<0x0A>"
])
self
.
head
=
PoolerHead
(
PoolerNormalize
())
def
_find_array
(
self
,
arr
:
array
,
target
:
array
,
start_idx
:
int
)
->
int
:
def
_find_array
(
self
,
arr
:
np
.
ndarray
,
target
:
np
.
ndarray
,
start_idx
:
int
=
0
,
end_idx
:
Optional
[
int
]
=
None
,
)
->
int
:
"""
Find the first occurrence of target in arr starting from start_idx.
Find the first occurrence of `target` in `arr` starting from
`start_idx`.
Args:
arr: The array to search within
target: The consecutive subsequence to find
start_idx: The starting index to search from
arr: The array to search within.
target: The consecutive subsequence to find.
start_idx: The starting index to search from (inclusive).
end_idx: The ending index to search from (exclusive).
Returns:
int:
The index of the first occurrence of target in arr.
The index of the first occurrence of
`
target
`
in
`
arr
`
.
"""
if
start_idx
<
0
:
raise
ValueError
(
"start_idx must be non-negative"
)
if
not
target
or
not
arr
:
raise
ValueError
(
"Empty arr or target not allowed"
)
raise
ValueError
(
"
`
start_idx
`
must be non-negative"
)
if
len
(
arr
)
==
0
or
len
(
target
)
==
0
:
raise
ValueError
(
"Empty
`
arr
`
or
`
target
`
not allowed"
)
arr_len
=
len
(
arr
)
target_len
=
len
(
target
)
for
i
in
range
(
start_idx
,
len
(
arr
)
-
target_len
+
1
):
if
arr
[
i
:
i
+
target_len
]
==
target
:
if
end_idx
is
None
:
end_idx
=
arr_len
for
i
in
range
(
start_idx
,
min
(
end_idx
,
arr_len
-
target_len
+
1
)):
if
(
arr
[
i
:
i
+
target_len
]
==
target
).
all
():
return
i
return
-
1
def
_get_instruction_len
(
self
,
prompt_token_ids
:
array
)
->
int
:
def
_get_instruction_len
(
self
,
prompt_token_ids
:
np
.
nd
array
)
->
int
:
"""
Get the length of the instruction in the prompt.
...
...
@@ -83,7 +102,6 @@ class GritLMPooler(nn.Module):
The pattern matching is done using integers instead of strings
because the prompt is given as a list of token IDs.
"""
instruction_len
=
0
# Return no instruction in case of missing BOS token.
...
...
@@ -98,7 +116,8 @@ class GritLMPooler(nn.Module):
embed_pattern_ids
=
self
.
embed_pattern_ids
if
self
.
_find_array
(
prompt_token_ids
,
self
.
user_pattern_ids
,
start_idx
=
1
)
==
1
:
start_idx
=
1
,
end_idx
=
2
)
==
1
:
embed_pattern_ids
=
self
.
embed_newline_pattern_ids
# Find the embed pattern in the prompt.
...
...
@@ -116,64 +135,92 @@ class GritLMPooler(nn.Module):
return
instruction_len
def
forward
(
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
# The equalities are split up to keep mypy happy
if
task
==
"encode"
or
task
==
"embed"
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
if
task
==
"classify"
or
task
==
"score"
:
return
None
assert_never
(
task
)
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""
Pool the hidden states by summing the embeddings of
non-instruction tokens.
"""
prompts_token_ids
=
[
token_ids
.
prompt_token_ids_array
for
_
,
token_ids
in
pooling_metadata
.
seq_data
.
items
()
]
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
instr_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with MEAN pooling"
return
hidden_states
[
instr_len
:].
mean
(
dim
=
0
,
dtype
=
torch
.
float32
)
instruction_lens
=
torch
.
tensor
(
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
instr_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
offset
=
0
pooled_data
=
list
[
torch
.
Tensor
]()
for
prompt_len
,
instr_len
in
zip
(
prompt_lens
,
instr_lens
):
pooled_data
.
append
(
hidden_states
[
offset
+
instr_len
:
offset
+
prompt_len
].
mean
(
dim
=
0
,
dtype
=
torch
.
float32
))
offset
+=
prompt_len
return
pooled_data
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
instr_lens
=
torch
.
tensor
(
[
self
.
_get_instruction_len
(
prompt_
token_ids
)
for
prompt_
token_ids
in
prompt
s
_token_ids
self
.
_get_instruction_len
(
token_ids
.
cpu
().
numpy
()
)
for
token_ids
in
get_
prompt_token_ids
(
pooling_metadata
)
],
device
=
hidden_state
s
.
device
,
device
=
prompt_len
s
.
device
,
)
prompt_lens
=
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
mask
=
torch
.
zeros_like
(
hidden_states
,
dtype
=
torch
.
bool
)
start_idx
=
0
for
prompt_len
,
instruction_len
in
zip
(
prompt_lens
,
instruction_lens
):
end_idx
=
start_idx
+
prompt_len
mask
[
start_idx
+
instruction_len
:
end_idx
]
=
True
start_idx
=
end_idx
if
isinstance
(
hidden_states
,
list
):
return
[
self
.
forward_one
(
h
,
prompt_len
,
instr_len
)
for
h
,
prompt_len
,
instr_len
in
zip
(
hidden_states
,
prompt_lens
,
instr_lens
)
]
masked_hidden_states
=
hidden_states
.
masked_fill
(
~
mask
,
0.0
)
return
self
.
forward_all
(
hidden_states
,
prompt_lens
,
instr_lens
)
sum_embeddings
=
torch
.
zeros
(
len
(
prompt_lens
),
hidden_states
.
size
(
1
),
device
=
hidden_states
.
device
)
start_idx
=
0
for
i
,
prompt_len
in
enumerate
(
prompt_lens
):
end_idx
=
start_idx
+
prompt_len
sum_embeddings
[
i
]
=
masked_hidden_states
[
start_idx
:
end_idx
].
sum
(
dim
=
0
)
start_idx
=
end_idx
class
GritLMPooler
(
Pooler
):
num_non_instruction_tokens
=
prompt_lens
-
instruction_lens
mean_embeddings
=
sum_embeddings
/
num_non_instruction_tokens
.
unsqueeze
(
1
)
def
__init__
(
self
,
model_config
:
ModelConfig
):
super
().
__init__
()
pooled_data
=
self
.
head
(
mean_embeddings
,
pooling_metadata
=
pooling_metadata
)
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
head
=
PoolerHead
(
PoolerNormalize
()
)
pooled_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
pooled_data
]
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
return
self
.
pooling
.
get_pooling_updates
(
task
)
return
PoolerOutput
(
outputs
=
pooled_outputs
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
class
GritLM
(
LlamaForCausalLM
,
SupportsV0Only
):
...
...
@@ -202,7 +249,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
prefix
:
str
=
""
,
**
kwargs
,
)
->
None
:
# Use full attention for pooling
# Use full attention for pooling
(this is why V1 is not supported yet)
if
vllm_config
.
model_config
.
runner_type
==
"pooling"
:
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
.
is_causal
=
False
...
...
vllm/model_executor/models/interfaces.py
View file @
45badd05
...
...
@@ -599,13 +599,6 @@ def supports_cross_encoding(
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
def
has_step_pooler
(
model
:
Union
[
type
[
object
],
object
])
->
bool
:
"""Check if the model uses step pooler."""
from
vllm.model_executor.layers.pooler
import
StepPooler
return
is_pooling_model
(
model
)
and
isinstance
(
model
.
pooler
,
StepPooler
)
class
SupportsQuant
:
"""The interface required for all models that support quantization."""
...
...
vllm/model_executor/models/modernbert.py
View file @
45badd05
...
...
@@ -14,14 +14,15 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingTask
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.pooling_params
import
Pooling
Params
from
vllm.pooling_params
import
Pooling
Task
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
...
...
@@ -270,8 +271,11 @@ class ModernBertPooler(Pooler):
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
,
)
->
Optional
[
PoolingParamsUpdate
]:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
forward
(
self
,
...
...
vllm/pooling_params.py
View file @
45badd05
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
import
msgspec
...
...
@@ -10,12 +10,14 @@ from vllm.sampling_params import RequestOutputKind
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
PoolingTask
=
Literal
[
"encode"
,
"embed"
,
"classify"
,
"score"
]
class
PoolingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""API parameters for pooling models.
This
"""API parameters for pooling models.
Attributes:
dimensions: Reduce the dimensions of embeddings
...
...
@@ -24,24 +26,33 @@ class PoolingParams(
dimensions
:
Optional
[
int
]
=
None
use_cross_encoder
:
bool
=
False
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
logits_processing_needs_token_ids
:
bool
=
Fals
e
task
:
Optional
[
PoolingTask
]
=
Non
e
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
requires_token_ids
:
bool
=
False
"""Internal use only."""
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
use_cross_encoder
=
self
.
use_cross_encoder
,
logits_processing_needs_token_ids
=
self
.
logits_processing_needs_token_ids
,
task
=
self
.
task
,
requires_token_ids
=
self
.
requires_token_ids
,
)
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
def
verify
(
self
,
task
:
PoolingTask
,
model_config
:
"ModelConfig"
)
->
None
:
if
self
.
task
is
None
:
self
.
task
=
task
elif
self
.
task
!=
task
:
msg
=
f
"You cannot overwrite
{
self
.
task
=
!
r
}
with
{
task
=
!
r
}
!"
raise
ValueError
(
msg
)
# NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included
# in this method
if
self
.
dimensions
is
not
None
:
if
not
model_config
.
is_matryoshka
:
raise
ValueError
(
...
...
@@ -61,12 +72,10 @@ class PoolingParams(
raise
ValueError
(
"Dimensions must be greater than 0"
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
return
(
f
"PoolingParams("
f
"dimensions=
{
self
.
dimensions
}
, "
f
"use_cross_encoder=
{
self
.
use_cross_encoder
}
, "
f
"logits_processing_needs_token_ids=
{
self
.
logits_processing_needs_token_ids
}
)"
)
f
"task=
{
self
.
task
}
, "
f
"requires_token_ids=
{
self
.
requires_token_ids
}
)"
)
def
__post_init__
(
self
)
->
None
:
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
...
...
vllm/v1/engine/core.py
View file @
45badd05
...
...
@@ -181,6 +181,12 @@ class EngineCore:
def
add_request
(
self
,
request
:
EngineCoreRequest
):
"""Add request to the scheduler."""
if
pooling_params
:
=
request
.
pooling_params
:
supported_pooling_tasks
=
(
self
.
model_executor
.
supported_pooling_tasks
)
if
pooling_params
.
task
not
in
supported_pooling_tasks
:
raise
ValueError
(
f
"Unsupported task:
{
pooling_params
.
task
!
r
}
"
f
"Supported tasks:
{
supported_pooling_tasks
}
"
)
if
request
.
mm_hashes
is
not
None
:
# Here, if hash exists for a multimodal input, then it will be
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
45badd05
...
...
@@ -8,7 +8,6 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models.interfaces
import
has_step_pooler
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
logger
=
init_logger
(
__name__
)
...
...
@@ -54,9 +53,6 @@ class CPUModelRunner(GPUModelRunner):
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
if
has_step_pooler
(
self
.
model
):
self
.
input_batch
.
logits_processing_needs_token_ids
=
True
if
self
.
lora_config
:
self
.
model
=
self
.
load_lora_model
(
self
.
model
,
self
.
model_config
,
self
.
scheduler_config
,
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
45badd05
...
...
@@ -70,7 +70,6 @@ class InputBatch:
vocab_size
:
int
,
block_sizes
:
list
[
int
],
# The block_size of each kv cache group
is_spec_decode
:
bool
=
False
,
logits_processing_needs_token_ids
:
bool
=
False
,
):
self
.
is_spec_decode
=
is_spec_decode
self
.
max_num_reqs
=
max_num_reqs
...
...
@@ -79,8 +78,6 @@ class InputBatch:
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
self
.
logits_processing_needs_token_ids
=
(
logits_processing_needs_token_ids
)
self
.
_req_ids
:
list
[
Optional
[
str
]]
=
[]
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
...
...
@@ -233,6 +230,9 @@ class InputBatch:
# req_index -> bad_words_token_ids
self
.
bad_words_token_ids
:
dict
[
int
,
list
[
list
[
int
]]]
=
{}
self
.
logits_processing_needs_token_ids
=
np
.
zeros
(
max_num_reqs
,
dtype
=
bool
)
self
.
req_output_token_ids
:
list
[
Optional
[
list
[
int
]]]
=
[]
# This is updated each time the batch constituents change.
...
...
@@ -365,9 +365,12 @@ class InputBatch:
if
sampling_params
.
bad_words_token_ids
:
self
.
bad_words_token_ids
[
req_index
]
=
sampling_params
.
bad_words_token_ids
elif
pooling_params
:
=
request
.
pooling_params
:
self
.
pooling_params
[
req_id
]
=
pooling_params
self
.
logits_processing_needs_token_ids
[
req_index
]
=
(
pooling_params
.
requires_token_ids
)
else
:
assert
request
.
pooling_params
is
not
None
self
.
pooling_params
[
req_id
]
=
request
.
pooling_params
raise
NotImplementedError
(
request
)
# Add request lora ID
if
request
.
lora_request
:
...
...
@@ -620,9 +623,9 @@ class InputBatch:
copy_slice
(
self
.
repetition_penalties_cpu_tensor
,
self
.
repetition_penalties
,
num_reqs
)
needs_prompt_token_ids
=
(
not
self
.
no_penalties
or
(
self
.
num_reqs
>
0
and
self
.
logits_processing_needs_token_ids
))
needs_prompt_token_ids
=
(
not
self
.
no_penalties
or
self
.
logits_processing_needs_token_ids
[:
num_reqs
].
any
(
))
if
needs_prompt_token_ids
:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
45badd05
...
...
@@ -4,7 +4,7 @@
import
gc
import
time
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
,
cast
,
get_args
import
numpy
as
np
import
torch
...
...
@@ -32,12 +32,13 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
has_step_pooler
,
is_mixture_of_experts
)
from
vllm.model_executor.models.interfaces
import
is_mixture_of_experts
from
vllm.model_executor.models.interfaces_base
import
(
VllmModelForPooling
,
is_pooling_model
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
,
PoolingTask
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
...
...
@@ -404,6 +405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_id
=
new_req_data
.
req_id
sampling_params
=
new_req_data
.
sampling_params
pooling_params
=
new_req_data
.
pooling_params
if
sampling_params
and
\
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
...
...
@@ -411,6 +413,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
generator
=
None
if
pooling_params
:
assert
pooling_params
.
task
is
not
None
,
(
"You did not set `task` in the API"
)
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
to_update
=
(
model
.
pooler
.
get_pooling_updates
(
pooling_params
.
task
))
assert
to_update
is
not
None
,
(
f
"
{
pooling_params
.
task
=
}
is not supported by the model"
)
to_update
.
apply
(
pooling_params
)
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
prompt_token_ids
=
new_req_data
.
prompt_token_ids
,
...
...
@@ -1092,6 +1106,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
def
get_supported_pooling_tasks
(
self
)
->
list
[
PoolingTask
]:
model
=
self
.
get_model
()
if
not
is_pooling_model
(
model
):
return
[]
return
[
task
for
task
in
get_args
(
PoolingTask
)
if
model
.
pooler
.
get_pooling_updates
(
task
)
]
def
apply_grammar_bitmask
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1737,8 +1761,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
model_loader
.
load_weights
(
self
.
model
,
model_config
=
self
.
model_config
)
if
has_step_pooler
(
self
.
model
):
self
.
input_batch
.
logits_processing_needs_token_ids
=
True
if
self
.
lora_config
:
self
.
model
=
self
.
load_lora_model
(
self
.
model
,
self
.
model_config
,
...
...
@@ -2138,16 +2160,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_num_tokens
=
num_tokens
//
num_reqs
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
dummy_task
=
self
.
get_supported_pooling_tasks
()[
0
]
dummy_pooling_params
=
PoolingParams
(
task
=
dummy_task
)
to_update
=
model
.
pooler
.
get_pooling_updates
(
dummy_task
)
assert
to_update
is
not
None
to_update
.
apply
(
dummy_pooling_params
)
dummy_metadata
=
PoolingMetadata
(
prompt_lens
=
torch
.
tensor
([
h
.
shape
[
0
]
for
h
in
hidden_states_list
],
device
=
self
.
device
),
prompt_token_ids
=
torch
.
zeros
((
num_reqs
,
req_num_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
),
pooling_params
=
[
P
ooling
P
arams
()
]
*
num_reqs
)
pooling_params
=
[
dummy_p
ooling
_p
arams
]
*
num_reqs
)
try
:
pooler_output
=
self
.
model
.
pooler
(
hidden_states
=
hidden_states_list
,
pooler_output
=
model
.
pooler
(
hidden_states
=
hidden_states_list
,
pooling_metadata
=
dummy_metadata
)
except
RuntimeError
as
e
:
if
'out of memory'
in
str
(
e
):
...
...
vllm/v1/worker/gpu_worker.py
View file @
45badd05
...
...
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.pooling_params
import
PoolingTask
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
GiB_bytes
,
MemorySnapshot
,
memory_profiling
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
...
...
@@ -309,6 +310,9 @@ class Worker(WorkerBase):
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model_runner
.
get_model
()
def
get_supported_pooling_tasks
(
self
)
->
list
[
PoolingTask
]:
return
self
.
model_runner
.
get_supported_pooling_tasks
()
@
torch
.
inference_mode
()
def
execute_model
(
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment