Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
006693ed
Commit
006693ed
authored
Dec 01, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.11.2' into v0.11.2-ori
parents
4b51e6f1
275de341
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
466 additions
and
57 deletions
+466
-57
examples/offline_inference/kv_load_failure_recovery/decode_example.py
...line_inference/kv_load_failure_recovery/decode_example.py
+85
-0
examples/offline_inference/kv_load_failure_recovery/prefill_example.py
...ine_inference/kv_load_failure_recovery/prefill_example.py
+58
-0
examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py
...v_load_failure_recovery/rogue_shared_storage_connector.py
+145
-0
examples/offline_inference/kv_load_failure_recovery/run.sh
examples/offline_inference/kv_load_failure_recovery/run.sh
+33
-0
examples/offline_inference/llm_engine_example.py
examples/offline_inference/llm_engine_example.py
+1
-1
examples/offline_inference/load_sharded_state.py
examples/offline_inference/load_sharded_state.py
+2
-2
examples/offline_inference/logits_processor/custom.py
examples/offline_inference/logits_processor/custom.py
+17
-4
examples/offline_inference/logits_processor/custom_req.py
examples/offline_inference/logits_processor/custom_req.py
+11
-10
examples/offline_inference/logits_processor/custom_req_init.py
...les/offline_inference/logits_processor/custom_req_init.py
+9
-10
examples/offline_inference/lora_with_quantization_inference.py
...les/offline_inference/lora_with_quantization_inference.py
+3
-4
examples/offline_inference/mlpspeculator.py
examples/offline_inference/mlpspeculator.py
+1
-2
examples/offline_inference/multilora_inference.py
examples/offline_inference/multilora_inference.py
+2
-4
examples/offline_inference/openai_batch/README.md
examples/offline_inference/openai_batch/README.md
+11
-3
examples/offline_inference/pooling/README.md
examples/offline_inference/pooling/README.md
+19
-1
examples/offline_inference/pooling/embed_jina_embeddings_v3.py
...les/offline_inference/pooling/embed_jina_embeddings_v3.py
+1
-1
examples/offline_inference/pooling/embed_matryoshka_fy.py
examples/offline_inference/pooling/embed_matryoshka_fy.py
+1
-1
examples/offline_inference/pooling/multi_vector_retrieval.py
examples/offline_inference/pooling/multi_vector_retrieval.py
+56
-0
examples/offline_inference/pooling/ner.py
examples/offline_inference/pooling/ner.py
+2
-2
examples/offline_inference/pooling/prithvi_geospatial_mae.py
examples/offline_inference/pooling/prithvi_geospatial_mae.py
+3
-3
examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py
..._inference/pooling/prithvi_geospatial_mae_io_processor.py
+6
-9
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
examples/offline_inference/kv_load_failure_recovery/decode_example.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
def
read_prompts
():
"""Read prompts from prefill_output.txt"""
prompts
=
[]
try
:
with
open
(
"prefill_output.txt"
)
as
f
:
for
line
in
f
:
prompts
.
append
(
line
.
strip
())
print
(
f
"Loaded
{
len
(
prompts
)
}
prompts from prefill_output.txt"
)
return
prompts
except
FileNotFoundError
:
print
(
"Error: prefill_output.txt file not found"
)
exit
(
-
1
)
def
main
():
prompts
=
read_prompts
()
sampling_params
=
SamplingParams
(
temperature
=
0
,
top_p
=
0.95
,
max_tokens
=
10
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--simulate-failure"
,
action
=
"store_true"
,
help
=
"Simulate KV load failure."
)
parser
.
add_argument
(
"--async-load"
,
action
=
"store_true"
,
help
=
"Simulate async KV load"
)
args
=
parser
.
parse_args
()
if
args
.
simulate_failure
:
ktc
=
KVTransferConfig
(
kv_connector
=
"RogueSharedStorageConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
,
"async_load"
:
args
.
async_load
,
},
kv_connector_module_path
=
"rogue_shared_storage_connector"
,
)
out_file
=
(
"async_decode_recovered_output.txt"
if
args
.
async_load
else
"sync_decode_recovered_output.txt"
)
else
:
ktc
=
KVTransferConfig
(
kv_connector
=
"SharedStorageConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
,
},
)
out_file
=
"decode_output.txt"
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.8
,
max_num_batched_tokens
=
64
,
max_num_seqs
=
16
,
kv_transfer_config
=
ktc
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
sep_str
=
"-"
*
30
with
open
(
out_file
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
out_str
=
f
"Prompt:
{
prompt
!
r
}
\n
Generated text:
{
generated_text
!
r
}
"
print
(
out_str
)
print
(
sep_str
)
f
.
write
(
out_str
)
f
.
write
(
sep_str
)
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/kv_load_failure_recovery/prefill_example.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
def
read_prompts
():
context
=
"Hi "
*
1000
context2
=
"Hey "
*
500
return
[
context
+
"Hello, my name is"
,
context
+
"The capital of France is"
,
context2
+
"Your name is"
,
context2
+
"The capital of China is"
,
]
def
main
():
prompts
=
read_prompts
()
sampling_params
=
SamplingParams
(
temperature
=
0
,
top_p
=
0.95
,
max_tokens
=
1
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.8
,
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"SharedStorageConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
},
),
)
# , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
)
new_prompts
=
[]
print
(
"-"
*
30
)
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
new_prompts
.
append
(
prompt
+
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
\n
Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
30
)
# Write new_prompts to prefill_output.txt
with
open
(
"prefill_output.txt"
,
"w"
)
as
f
:
for
prompt
in
new_prompts
:
f
.
write
(
prompt
+
"
\n
"
)
print
(
f
"Saved
{
len
(
new_prompts
)
}
prompts to prefill_output.txt"
)
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
logging
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
SharedStorageConnector
,
SharedStorageConnectorMetadata
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
logger
=
logging
.
getLogger
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
@
dataclass
class
RogueSharedStorageConnectorMetadata
(
SharedStorageConnectorMetadata
):
req_to_block_ids
:
dict
[
str
,
set
[
int
]]
=
field
(
default_factory
=
dict
)
@
classmethod
def
from_base
(
cls
,
base
:
SharedStorageConnectorMetadata
):
return
cls
(
requests
=
base
.
requests
)
class
RogueSharedStorageConnector
(
SharedStorageConnector
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_async_load
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"async_load"
,
False
)
self
.
_invalid_block_ids
:
set
=
None
self
.
_seen_requests
:
set
=
set
()
self
.
_req_to_block_ids
:
dict
[
str
,
list
[
int
]]
=
dict
()
def
bind_connector_metadata
(
self
,
connector_metadata
:
KVConnectorMetadata
)
->
None
:
assert
isinstance
(
connector_metadata
,
RogueSharedStorageConnectorMetadata
)
index
,
failed_request
=
next
(
(
(
i
,
x
)
for
i
,
x
in
enumerate
(
connector_metadata
.
requests
)
if
not
x
.
is_store
),
(
None
,
None
),
)
if
index
is
not
None
:
del
connector_metadata
.
requests
[
index
]
self
.
_invalid_block_ids
=
set
(
(
failed_request
.
slot_mapping
[::
self
.
_block_size
]
//
self
.
_block_size
).
tolist
()
)
logger
.
info
(
"Simulating failure to load all KV blocks for the "
"first load request. Total blocks: %d"
,
len
(
self
.
_invalid_block_ids
),
)
super
().
bind_connector_metadata
(
connector_metadata
)
def
clear_connector_metadata
(
self
)
->
None
:
self
.
_invalid_block_ids
=
None
super
().
clear_connector_metadata
()
def
start_load_kv
(
self
,
forward_context
:
ForwardContext
,
**
kwargs
)
->
None
:
if
self
.
_async_load
and
forward_context
.
attn_metadata
is
None
:
# Bypass sanity check in super().start_load_kv
forward_context
.
attn_metadata
=
"None"
super
().
start_load_kv
(
forward_context
,
**
kwargs
)
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
)
->
tuple
[
set
[
str
]
|
None
,
set
[
str
]
|
None
]:
if
self
.
_async_load
:
meta
=
self
.
_get_connector_metadata
()
assert
isinstance
(
meta
,
RogueSharedStorageConnectorMetadata
)
if
meta
.
req_to_block_ids
:
return
None
,
set
(
meta
.
req_to_block_ids
)
return
None
,
None
def
get_block_ids_with_load_errors
(
self
)
->
set
[
int
]:
return
self
.
_invalid_block_ids
def
get_num_new_matched_tokens
(
self
,
request
:
Request
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
,
bool
]:
if
request
.
request_id
in
self
.
_seen_requests
:
return
0
,
False
self
.
_seen_requests
.
add
(
request
.
request_id
)
num_tokens
,
_
=
super
().
get_num_new_matched_tokens
(
request
,
num_computed_tokens
)
return
num_tokens
,
self
.
_async_load
and
num_tokens
>
0
def
update_state_after_alloc
(
self
,
request
:
Request
,
blocks
:
KVCacheBlocks
,
num_external_tokens
:
int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
super
().
update_state_after_alloc
(
request
,
blocks
,
num_external_tokens
)
if
num_external_tokens
>
0
:
self
.
_req_to_block_ids
[
request
.
request_id
]
=
blocks
.
get_block_ids
()[
0
]
def
build_connector_meta
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
KVConnectorMetadata
:
if
not
self
.
_async_load
:
base
=
super
().
build_connector_meta
(
scheduler_output
)
meta
=
RogueSharedStorageConnectorMetadata
.
from_base
(
base
)
else
:
meta
=
RogueSharedStorageConnectorMetadata
()
if
self
.
_requests_need_load
:
for
req_id
,
request
in
self
.
_requests_need_load
.
items
():
meta
.
add_request
(
token_ids
=
request
.
prompt_token_ids
,
block_ids
=
self
.
_req_to_block_ids
[
req_id
],
block_size
=
self
.
_block_size
,
is_store
=
False
,
mm_hashes
=
[],
)
# Clear state
self
.
_requests_need_load
.
clear
()
meta
.
req_to_block_ids
=
self
.
_req_to_block_ids
self
.
_req_to_block_ids
=
dict
()
return
meta
examples/offline_inference/kv_load_failure_recovery/run.sh
0 → 100755
View file @
006693ed
#!/bin/bash
# Constants
SHARED_STORAGE_DIR
=
"local_storage"
PREFILL_OUTPUT
=
"prefill_output.txt"
DECODE_OUTPUT
=
"decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT
=
"sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT
=
"async_decode_recovered_output.txt"
# Cleanup
rm
-rf
"
$SHARED_STORAGE_DIR
"
rm
-f
"
$PREFILL_OUTPUT
"
"
$DECODE_OUTPUT
"
"
$SYNC_DECODE_RECOVERED_OUTPUT
"
"
$ASYNC_DECODE_RECOVERED_OUTPUT
"
# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 decode_example.py
--simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 decode_example.py
--simulate-failure
--async-load
# Compare outputs
if
!
cmp
-s
"
$DECODE_OUTPUT
"
"
$SYNC_DECODE_RECOVERED_OUTPUT
"
;
then
echo
"❌ Outputs differ: sync recovery failed."
diff
-u
"
$DECODE_OUTPUT
"
"
$SYNC_DECODE_RECOVERED_OUTPUT
"
exit
1
fi
if
!
cmp
-s
"
$DECODE_OUTPUT
"
"
$ASYNC_DECODE_RECOVERED_OUTPUT
"
;
then
echo
"❌ Outputs differ: async recovery failed."
diff
-u
"
$DECODE_OUTPUT
"
"
$ASYNC_DECODE_RECOVERED_OUTPUT
"
exit
1
fi
echo
"✅ Outputs match: recovery successful."
examples/offline_inference/llm_engine_example.py
View file @
006693ed
...
@@ -8,7 +8,7 @@ for processing prompts with various sampling parameters.
...
@@ -8,7 +8,7 @@ for processing prompts with various sampling parameters.
import
argparse
import
argparse
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
create_test_prompts
()
->
list
[
tuple
[
str
,
SamplingParams
]]:
def
create_test_prompts
()
->
list
[
tuple
[
str
,
SamplingParams
]]:
...
...
examples/offline_inference/load_sharded_state.py
View file @
006693ed
...
@@ -11,7 +11,7 @@ python save_sharded_state.py \
...
@@ -11,7 +11,7 @@ python save_sharded_state.py \
--model /path/to/load
\
--model /path/to/load
\
--quantization deepspeedfp
\
--quantization deepspeedfp
\
--tensor-parallel-size 8
\
--tensor-parallel-size 8
\
--output /path/to/save/sharded/model
e
--output /path/to/save/sharded/model
python load_sharded_state.py
\
python load_sharded_state.py
\
--model /path/to/saved/sharded/model
\
--model /path/to/saved/sharded/model
\
...
@@ -25,7 +25,7 @@ python load_sharded_state.py \
...
@@ -25,7 +25,7 @@ python load_sharded_state.py \
import
dataclasses
import
dataclasses
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
parse_args
():
def
parse_args
():
...
...
examples/offline_inference/logits_processor/custom.py
View file @
006693ed
...
@@ -33,7 +33,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
...
@@ -33,7 +33,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
------------------------------------------------------------
"""
"""
from
typing
import
Optional
from
typing
import
Any
import
torch
import
torch
...
@@ -50,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
...
@@ -50,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
class
DummyLogitsProcessor
(
LogitsProcessor
):
class
DummyLogitsProcessor
(
LogitsProcessor
):
"""Fake logit processor to support unit testing and examples"""
"""Fake logit processor to support unit testing and examples"""
@
classmethod
def
validate_params
(
cls
,
params
:
SamplingParams
):
target_token
:
Any
|
None
=
params
.
extra_args
and
params
.
extra_args
.
get
(
"target_token"
)
if
target_token
is
not
None
and
not
isinstance
(
target_token
,
int
):
raise
ValueError
(
f
"target_token value
{
target_token
}
{
type
(
target_token
)
}
is not int"
)
def
__init__
(
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
):
):
...
@@ -58,15 +68,18 @@ class DummyLogitsProcessor(LogitsProcessor):
...
@@ -58,15 +68,18 @@ class DummyLogitsProcessor(LogitsProcessor):
def
is_argmax_invariant
(
self
)
->
bool
:
def
is_argmax_invariant
(
self
)
->
bool
:
return
False
return
False
def
update_state
(
self
,
batch_update
:
Optional
[
BatchUpdate
]):
def
update_state
(
self
,
batch_update
:
BatchUpdate
|
None
):
def
extract_extra_arg
(
params
:
SamplingParams
)
->
int
|
None
:
self
.
validate_params
(
params
)
return
params
.
extra_args
and
params
.
extra_args
.
get
(
"target_token"
)
process_dict_updates
(
process_dict_updates
(
self
.
req_info
,
self
.
req_info
,
batch_update
,
batch_update
,
# This function returns the LP's per-request state based on the
# This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the
# request details, or None if this LP does not apply to the
# request.
# request.
lambda
params
,
_
,
__
:
params
.
extra_args
lambda
params
,
_
,
__
:
extract_extra_arg
(
params
),
and
(
params
.
extra_args
.
get
(
"target_token"
)),
)
)
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
examples/offline_inference/logits_processor/custom_req.py
View file @
006693ed
...
@@ -39,7 +39,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
...
@@ -39,7 +39,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
------------------------------------------------------------
"""
"""
from
typing
import
Any
,
Optional
from
typing
import
Any
import
torch
import
torch
...
@@ -76,13 +76,21 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
...
@@ -76,13 +76,21 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""
batch-level logits processor"""
@
classmethod
def
validate_params
(
cls
,
params
:
SamplingParams
):
target_token
:
Any
|
None
=
params
.
extra_args
and
params
.
extra_args
.
get
(
"target_token"
)
if
target_token
is
not
None
and
not
isinstance
(
target_token
,
int
):
raise
ValueError
(
f
"target_token value
{
target_token
}
is not int"
)
def
is_argmax_invariant
(
self
)
->
bool
:
def
is_argmax_invariant
(
self
)
->
bool
:
return
False
return
False
def
new_req_logits_processor
(
def
new_req_logits_processor
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
)
->
Optional
[
RequestLogitsProcessor
]
:
)
->
RequestLogitsProcessor
|
None
:
"""This method returns a new request-level logits processor, customized
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.
to the `target_token` value associated with a particular request.
...
@@ -96,18 +104,11 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
...
@@ -96,18 +104,11 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
Returns:
Returns:
`Callable` request logits processor, or None
`Callable` request logits processor, or None
"""
"""
target_token
:
Optional
[
Any
]
=
params
.
extra_args
and
params
.
extra_args
.
get
(
target_token
:
Any
|
None
=
params
.
extra_args
and
params
.
extra_args
.
get
(
"target_token"
"target_token"
)
)
if
target_token
is
None
:
if
target_token
is
None
:
return
None
return
None
if
not
isinstance
(
target_token
,
int
):
logger
.
warning
(
"target_token value %s is not int; not applying logits"
" processor to request."
,
target_token
,
)
return
None
return
DummyPerReqLogitsProcessor
(
target_token
)
return
DummyPerReqLogitsProcessor
(
target_token
)
...
...
examples/offline_inference/logits_processor/custom_req_init.py
View file @
006693ed
...
@@ -41,8 +41,6 @@ which indicates that the logits processor is running. However, on a non-"cuda"
...
@@ -41,8 +41,6 @@ which indicates that the logits processor is running. However, on a non-"cuda"
device, the first and third requests would not repeat the same token.
device, the first and third requests would not repeat the same token.
"""
"""
from
typing
import
Optional
import
torch
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
...
@@ -79,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
...
@@ -79,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""
info about the device type"""
@
classmethod
def
validate_params
(
cls
,
params
:
SamplingParams
):
target_token
=
params
.
extra_args
and
params
.
extra_args
.
get
(
"target_token"
)
if
target_token
is
not
None
and
not
isinstance
(
target_token
,
int
):
raise
ValueError
(
f
"`target_token` has to be an integer, got
{
target_token
}
."
)
def
__init__
(
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
is_pin_memory
:
bool
):
):
...
@@ -91,7 +97,7 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
...
@@ -91,7 +97,7 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
def
new_req_logits_processor
(
def
new_req_logits_processor
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
)
->
Optional
[
RequestLogitsProcessor
]
:
)
->
RequestLogitsProcessor
|
None
:
"""This method returns a new request-level logits processor, customized
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.
to the `target_token` value associated with a particular request.
...
@@ -115,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
...
@@ -115,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
is
None
is
None
):
):
return
None
return
None
if
not
isinstance
(
target_token
,
int
):
logger
.
warning
(
"target_token value %s is not int; not applying logits"
" processor to request."
,
target_token
,
)
return
None
return
DummyPerReqLogitsProcessor
(
target_token
)
return
DummyPerReqLogitsProcessor
(
target_token
)
...
...
examples/offline_inference/lora_with_quantization_inference.py
View file @
006693ed
...
@@ -8,7 +8,6 @@ Requires HuggingFace credentials for access.
...
@@ -8,7 +8,6 @@ Requires HuggingFace credentials for access.
"""
"""
import
gc
import
gc
from
typing
import
Optional
import
torch
import
torch
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
...
@@ -19,7 +18,7 @@ from vllm.lora.request import LoRARequest
...
@@ -19,7 +18,7 @@ from vllm.lora.request import LoRARequest
def
create_test_prompts
(
def
create_test_prompts
(
lora_path
:
str
,
lora_path
:
str
,
)
->
list
[
tuple
[
str
,
SamplingParams
,
Optional
[
LoRARequest
]
]]:
)
->
list
[
tuple
[
str
,
SamplingParams
,
LoRARequest
|
None
]]:
return
[
return
[
# this is an example of using quantization without LoRA
# this is an example of using quantization without LoRA
(
(
...
@@ -56,7 +55,7 @@ def create_test_prompts(
...
@@ -56,7 +55,7 @@ def create_test_prompts(
def
process_requests
(
def
process_requests
(
engine
:
LLMEngine
,
engine
:
LLMEngine
,
test_prompts
:
list
[
tuple
[
str
,
SamplingParams
,
Optional
[
LoRARequest
]
]],
test_prompts
:
list
[
tuple
[
str
,
SamplingParams
,
LoRARequest
|
None
]],
):
):
"""Continuously process a list of prompts and handle the outputs."""
"""Continuously process a list of prompts and handle the outputs."""
request_id
=
0
request_id
=
0
...
@@ -78,7 +77,7 @@ def process_requests(
...
@@ -78,7 +77,7 @@ def process_requests(
def
initialize_engine
(
def
initialize_engine
(
model
:
str
,
quantization
:
str
,
lora_repo
:
Optional
[
str
]
model
:
str
,
quantization
:
str
,
lora_repo
:
str
|
None
)
->
LLMEngine
:
)
->
LLMEngine
:
"""Initialize the LLMEngine."""
"""Initialize the LLMEngine."""
...
...
examples/offline_inference/mlpspeculator.py
View file @
006693ed
...
@@ -4,8 +4,7 @@
...
@@ -4,8 +4,7 @@
This file demonstrates the usage of text generation with an LLM model,
This file demonstrates the usage of text generation with an LLM model,
comparing the performance with and without speculative decoding.
comparing the performance with and without speculative decoding.
Note that still not support `v1`:
Note that this example is out of date and not supported in vLLM v1.
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""
"""
import
gc
import
gc
...
...
examples/offline_inference/multilora_inference.py
View file @
006693ed
...
@@ -7,8 +7,6 @@ for offline inference.
...
@@ -7,8 +7,6 @@ for offline inference.
Requires HuggingFace credentials for access to Llama2.
Requires HuggingFace credentials for access to Llama2.
"""
"""
from
typing
import
Optional
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
...
@@ -17,7 +15,7 @@ from vllm.lora.request import LoRARequest
...
@@ -17,7 +15,7 @@ from vllm.lora.request import LoRARequest
def
create_test_prompts
(
def
create_test_prompts
(
lora_path
:
str
,
lora_path
:
str
,
)
->
list
[
tuple
[
str
,
SamplingParams
,
Optional
[
LoRARequest
]
]]:
)
->
list
[
tuple
[
str
,
SamplingParams
,
LoRARequest
|
None
]]:
"""Create a list of test prompts with their sampling parameters.
"""Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2
2 requests for base model, 4 requests for the LoRA. We define 2
...
@@ -68,7 +66,7 @@ def create_test_prompts(
...
@@ -68,7 +66,7 @@ def create_test_prompts(
def
process_requests
(
def
process_requests
(
engine
:
LLMEngine
,
engine
:
LLMEngine
,
test_prompts
:
list
[
tuple
[
str
,
SamplingParams
,
Optional
[
LoRARequest
]
]],
test_prompts
:
list
[
tuple
[
str
,
SamplingParams
,
LoRARequest
|
None
]],
):
):
"""Continuously process a list of prompts and handle the outputs."""
"""Continuously process a list of prompts and handle the outputs."""
request_id
=
0
request_id
=
0
...
...
examples/offline_inference/openai_batch/README.md
View file @
006693ed
...
@@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
...
@@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
"""
"""
try
:
try
:
url
=
s3_client
.
generate_presigned_url
(
url
=
s3_client
.
generate_presigned_url
(
ClientMethod
=
client_method
,
Params
=
method_parameters
,
ExpiresIn
=
expires_in
ClientMethod
=
client_method
,
Params
=
method_parameters
,
ExpiresIn
=
expires_in
,
)
)
except
ClientError
:
except
ClientError
:
raise
raise
...
@@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
...
@@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
s3_client
=
boto3
.
client
(
"s3"
)
s3_client
=
boto3
.
client
(
"s3"
)
input_url
=
generate_presigned_url
(
input_url
=
generate_presigned_url
(
s3_client
,
"get_object"
,
{
"Bucket"
:
"MY_BUCKET"
,
"Key"
:
"MY_INPUT_FILE.jsonl"
},
3600
s3_client
,
"get_object"
,
{
"Bucket"
:
"MY_BUCKET"
,
"Key"
:
"MY_INPUT_FILE.jsonl"
},
expires_in
=
3600
,
)
)
output_url
=
generate_presigned_url
(
output_url
=
generate_presigned_url
(
s3_client
,
"put_object"
,
{
"Bucket"
:
"MY_BUCKET"
,
"Key"
:
"MY_OUTPUT_FILE.jsonl"
},
3600
s3_client
,
"put_object"
,
{
"Bucket"
:
"MY_BUCKET"
,
"Key"
:
"MY_OUTPUT_FILE.jsonl"
},
expires_in
=
3600
,
)
)
print
(
f
"
{
input_url
=
}
"
)
print
(
f
"
{
input_url
=
}
"
)
print
(
f
"
{
output_url
=
}
"
)
print
(
f
"
{
output_url
=
}
"
)
...
...
examples/offline_inference/pooling/README.md
View file @
006693ed
...
@@ -14,7 +14,7 @@ python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_na
...
@@ -14,7 +14,7 @@ python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_na
## Embed jina_embeddings_v3 usage
## Embed jina_embeddings_v3 usage
Only text matching task is supported for now. See
<
gh-pr:
16120>
Only text matching task is supported for now. See
<
https://github.com/vllm-project/vllm/pull/
16120>
```
bash
```
bash
python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
...
@@ -26,12 +26,30 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
...
@@ -26,12 +26,30 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
```
```
## Multi vector retrieval usage
```
bash
python examples/offline_inference/pooling/multi_vector_retrieval.py
```
## Named Entity Recognition (NER) usage
## Named Entity Recognition (NER) usage
```
bash
```
bash
python examples/offline_inference/pooling/ner.py
python examples/offline_inference/pooling/ner.py
```
```
## Prithvi Geospatial MAE usage
```
bash
python examples/offline_inference/pooling/prithvi_geospatial_mae.py
```
## IO Processor Plugins for Prithvi Geospatial MAE
```
bash
python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py
```
## Qwen3 reranker usage
## Qwen3 reranker usage
```
bash
```
bash
...
...
examples/offline_inference/pooling/embed_jina_embeddings_v3.py
View file @
006693ed
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
argparse
import
Namespace
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
from
vllm
import
LLM
,
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
parse_args
():
def
parse_args
():
...
...
examples/offline_inference/pooling/embed_matryoshka_fy.py
View file @
006693ed
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
argparse
import
Namespace
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
,
PoolingParams
from
vllm
import
LLM
,
EngineArgs
,
PoolingParams
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
parse_args
():
def
parse_args
():
...
...
examples/offline_inference/pooling/multi_vector_retrieval.py
0 → 100644
View file @
006693ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
def
parse_args
():
parser
=
FlexibleArgumentParser
()
parser
=
EngineArgs
.
add_cli_args
(
parser
)
# Set example specific arguments
parser
.
set_defaults
(
model
=
"BAAI/bge-m3"
,
runner
=
"pooling"
,
enforce_eager
=
True
,
)
return
parser
.
parse_args
()
def
main
(
args
:
Namespace
):
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create an LLM.
# You should pass runner="pooling" for embedding models
llm
=
LLM
(
**
vars
(
args
))
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs
=
llm
.
embed
(
prompts
)
# Print the outputs.
print
(
"
\n
Generated Outputs:
\n
"
+
"-"
*
60
)
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
embeds
=
output
.
outputs
.
embedding
print
(
len
(
embeds
))
# Generate embedding for each token. The output is a list of PoolingRequestOutput.
outputs
=
llm
.
encode
(
prompts
,
pooling_task
=
"token_embed"
)
# Print the outputs.
print
(
"
\n
Generated Outputs:
\n
"
+
"-"
*
60
)
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
multi_vector
=
output
.
outputs
.
data
print
(
multi_vector
.
shape
)
if
__name__
==
"__main__"
:
args
=
parse_args
()
main
(
args
)
examples/offline_inference/pooling/ner.py
View file @
006693ed
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
from
argparse
import
Namespace
from
argparse
import
Namespace
from
vllm
import
LLM
,
EngineArgs
from
vllm
import
LLM
,
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
parse_args
():
def
parse_args
():
...
@@ -33,7 +33,7 @@ def main(args: Namespace):
...
@@ -33,7 +33,7 @@ def main(args: Namespace):
label_map
=
llm
.
llm_engine
.
vllm_config
.
model_config
.
hf_config
.
id2label
label_map
=
llm
.
llm_engine
.
vllm_config
.
model_config
.
hf_config
.
id2label
# Run inference
# Run inference
outputs
=
llm
.
encode
(
prompts
)
outputs
=
llm
.
encode
(
prompts
,
pooling_task
=
"token_classify"
)
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
for
prompt
,
output
in
zip
(
prompts
,
outputs
):
logits
=
output
.
outputs
.
data
logits
=
output
.
outputs
.
data
...
...
examples/offline_inference/prithvi_geospatial_mae.py
→
examples/offline_inference/
pooling/
prithvi_geospatial_mae.py
View file @
006693ed
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
argparse
import
argparse
import
datetime
import
datetime
import
os
import
os
from
typing
import
Union
import
albumentations
import
albumentations
import
numpy
as
np
import
numpy
as
np
...
@@ -50,6 +49,7 @@ class PrithviMAE:
...
@@ -50,6 +49,7 @@ class PrithviMAE:
dtype
=
"float16"
,
dtype
=
"float16"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
model_impl
=
"terratorch"
,
model_impl
=
"terratorch"
,
enable_mm_embeds
=
True
,
)
)
def
run
(
self
,
input_data
,
location_coords
):
def
run
(
self
,
input_data
,
location_coords
):
...
@@ -64,7 +64,7 @@ class PrithviMAE:
...
@@ -64,7 +64,7 @@ class PrithviMAE:
}
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
outputs
=
self
.
model
.
encode
(
prompt
,
use_tqdm
=
False
)
outputs
=
self
.
model
.
encode
(
prompt
,
pooling_task
=
"plugin"
,
use_tqdm
=
False
)
return
outputs
[
0
].
outputs
.
data
return
outputs
[
0
].
outputs
.
data
...
@@ -160,7 +160,7 @@ def load_example(
...
@@ -160,7 +160,7 @@ def load_example(
file_paths
:
list
[
str
],
file_paths
:
list
[
str
],
mean
:
list
[
float
]
=
None
,
mean
:
list
[
float
]
=
None
,
std
:
list
[
float
]
=
None
,
std
:
list
[
float
]
=
None
,
indices
:
Union
[
list
[
int
]
,
None
]
=
None
,
indices
:
list
[
int
]
|
None
=
None
,
):
):
"""Build an input example by loading images in *file_paths*.
"""Build an input example by loading images in *file_paths*.
...
...
examples/offline_inference/prithvi_geospatial_mae_io_processor.py
→
examples/offline_inference/
pooling/
prithvi_geospatial_mae_io_processor.py
View file @
006693ed
...
@@ -6,14 +6,14 @@ import os
...
@@ -6,14 +6,14 @@ import os
import
torch
import
torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.pooling_params
import
PoolingParams
# This example shows how to perform an offline inference that generates
# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# image as input, process it using the multimodal data processor, and
# perform inference.
# perform inference.
# Requirement - install plugin at:
# Requirements:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
def
main
():
def
main
():
...
@@ -36,15 +36,12 @@ def main():
...
@@ -36,15 +36,12 @@ def main():
# to avoid the model going OOM.
# to avoid the model going OOM.
# The maximum number depends on the available GPU memory
# The maximum number depends on the available GPU memory
max_num_seqs
=
32
,
max_num_seqs
=
32
,
io_processor_plugin
=
"
prithvi_to_tiff
"
,
io_processor_plugin
=
"
terratorch_segmentation
"
,
model_impl
=
"terratorch"
,
model_impl
=
"terratorch"
,
enable_mm_embeds
=
True
,
)
)
pooling_params
=
PoolingParams
(
task
=
"encode"
,
softmax
=
False
)
pooler_output
=
llm
.
encode
(
img_prompt
,
pooling_task
=
"plugin"
)
pooler_output
=
llm
.
encode
(
img_prompt
,
pooling_params
=
pooling_params
,
)
output
=
pooler_output
[
0
].
outputs
output
=
pooler_output
[
0
].
outputs
print
(
output
)
print
(
output
)
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment