Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
"""Read prompts from prefill_output.txt"""
prompts = []
try:
with open("prefill_output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
return prompts
except FileNotFoundError:
print("Error: prefill_output.txt file not found")
exit(-1)
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
parser = argparse.ArgumentParser()
parser.add_argument(
"--simulate-failure", action="store_true", help="Simulate KV load failure."
)
parser.add_argument(
"--async-load", action="store_true", help="Simulate async KV load"
)
args = parser.parse_args()
if args.simulate_failure:
ktc = KVTransferConfig(
kv_connector="RogueSharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
"async_load": args.async_load,
},
kv_connector_module_path="rogue_shared_storage_connector",
)
out_file = (
"async_decode_recovered_output.txt"
if args.async_load
else "sync_decode_recovered_output.txt"
)
else:
ktc = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
},
)
out_file = "decode_output.txt"
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=ktc,
)
outputs = llm.generate(prompts, sampling_params)
sep_str = "-" * 30
with open(out_file, "w", encoding="utf-8") as f:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
print(out_str)
print(sep_str)
f.write(out_str)
f.write(sep_str)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)
new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
# Write new_prompts to prefill_output.txt
with open("prefill_output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)
@classmethod
def from_base(cls, base: SharedStorageConnectorMetadata):
return cls(requests=base.requests)
class RogueSharedStorageConnector(SharedStorageConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
"async_load", False
)
self._invalid_block_ids: set = None
self._seen_requests: set = set()
self._req_to_block_ids: dict[str, list[int]] = dict()
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
index, failed_request = next(
(
(i, x)
for i, x in enumerate(connector_metadata.requests)
if not x.is_store
),
(None, None),
)
if index is not None:
del connector_metadata.requests[index]
self._invalid_block_ids = set(
(
failed_request.slot_mapping[:: self._block_size] // self._block_size
).tolist()
)
logger.info(
"Simulating failure to load all KV blocks for the "
"first load request. Total blocks: %d",
len(self._invalid_block_ids),
)
super().bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
self._invalid_block_ids = None
super().clear_connector_metadata()
def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
if self._async_load and forward_context.attn_metadata is None:
# Bypass sanity check in super().start_load_kv
forward_context.attn_metadata = "None"
super().start_load_kv(forward_context, **kwargs)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[set[str] | None, set[str] | None]:
if self._async_load:
meta = self._get_connector_metadata()
assert isinstance(meta, RogueSharedStorageConnectorMetadata)
if meta.req_to_block_ids:
return None, set(meta.req_to_block_ids)
return None, None
def get_block_ids_with_load_errors(self) -> set[int]:
return self._invalid_block_ids
def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int, bool]:
if request.request_id in self._seen_requests:
return 0, False
self._seen_requests.add(request.request_id)
num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
return num_tokens, self._async_load and num_tokens > 0
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
super().update_state_after_alloc(request, blocks, num_external_tokens)
if num_external_tokens > 0:
self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]
def build_connector_meta(
self,
scheduler_output: "SchedulerOutput",
) -> KVConnectorMetadata:
if not self._async_load:
base = super().build_connector_meta(scheduler_output)
meta = RogueSharedStorageConnectorMetadata.from_base(base)
else:
meta = RogueSharedStorageConnectorMetadata()
if self._requests_need_load:
for req_id, request in self._requests_need_load.items():
meta.add_request(
token_ids=request.prompt_token_ids,
block_ids=self._req_to_block_ids[req_id],
block_size=self._block_size,
is_store=False,
mm_hashes=[],
)
# Clear state
self._requests_need_load.clear()
meta.req_to_block_ids = self._req_to_block_ids
self._req_to_block_ids = dict()
return meta
#!/bin/bash
# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"
# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load
# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: sync recovery failed."
diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: async recovery failed."
diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
echo "✅ Outputs match: recovery successful."
...@@ -8,7 +8,7 @@ for processing prompts with various sampling parameters. ...@@ -8,7 +8,7 @@ for processing prompts with various sampling parameters.
import argparse import argparse
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_test_prompts() -> list[tuple[str, SamplingParams]]: def create_test_prompts() -> list[tuple[str, SamplingParams]]:
......
...@@ -11,7 +11,7 @@ python save_sharded_state.py \ ...@@ -11,7 +11,7 @@ python save_sharded_state.py \
--model /path/to/load \ --model /path/to/load \
--quantization deepspeedfp \ --quantization deepspeedfp \
--tensor-parallel-size 8 \ --tensor-parallel-size 8 \
--output /path/to/save/sharded/modele --output /path/to/save/sharded/model
python load_sharded_state.py \ python load_sharded_state.py \
--model /path/to/saved/sharded/model \ --model /path/to/saved/sharded/model \
...@@ -25,7 +25,7 @@ python load_sharded_state.py \ ...@@ -25,7 +25,7 @@ python load_sharded_state.py \
import dataclasses import dataclasses
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args(): def parse_args():
......
...@@ -33,7 +33,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the' ...@@ -33,7 +33,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------ ------------------------------------------------------------
""" """
from typing import Optional from typing import Any
import torch import torch
...@@ -50,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates ...@@ -50,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
class DummyLogitsProcessor(LogitsProcessor): class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples""" """Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)
def __init__( def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
): ):
...@@ -58,15 +68,18 @@ class DummyLogitsProcessor(LogitsProcessor): ...@@ -58,15 +68,18 @@ class DummyLogitsProcessor(LogitsProcessor):
def is_argmax_invariant(self) -> bool: def is_argmax_invariant(self) -> bool:
return False return False
def update_state(self, batch_update: Optional[BatchUpdate]): def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")
process_dict_updates( process_dict_updates(
self.req_info, self.req_info,
batch_update, batch_update,
# This function returns the LP's per-request state based on the # This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the # request details, or None if this LP does not apply to the
# request. # request.
lambda params, _, __: params.extra_args lambda params, _, __: extract_extra_arg(params),
and (params.extra_args.get("target_token")),
) )
def apply(self, logits: torch.Tensor) -> torch.Tensor: def apply(self, logits: torch.Tensor) -> torch.Tensor:
......
...@@ -39,7 +39,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the' ...@@ -39,7 +39,7 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------ ------------------------------------------------------------
""" """
from typing import Any, Optional from typing import Any
import torch import torch
...@@ -76,13 +76,21 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): ...@@ -76,13 +76,21 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a """Example of wrapping a fake request-level logit processor to create a
batch-level logits processor""" batch-level logits processor"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")
def is_argmax_invariant(self) -> bool: def is_argmax_invariant(self) -> bool:
return False return False
def new_req_logits_processor( def new_req_logits_processor(
self, self,
params: SamplingParams, params: SamplingParams,
) -> Optional[RequestLogitsProcessor]: ) -> RequestLogitsProcessor | None:
"""This method returns a new request-level logits processor, customized """This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request. to the `target_token` value associated with a particular request.
...@@ -96,18 +104,11 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): ...@@ -96,18 +104,11 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
Returns: Returns:
`Callable` request logits processor, or None `Callable` request logits processor, or None
""" """
target_token: Optional[Any] = params.extra_args and params.extra_args.get( target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token" "target_token"
) )
if target_token is None: if target_token is None:
return None return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token) return DummyPerReqLogitsProcessor(target_token)
......
...@@ -41,8 +41,6 @@ which indicates that the logits processor is running. However, on a non-"cuda" ...@@ -41,8 +41,6 @@ which indicates that the logits processor is running. However, on a non-"cuda"
device, the first and third requests would not repeat the same token. device, the first and third requests would not repeat the same token.
""" """
from typing import Optional
import torch import torch
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
...@@ -79,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): ...@@ -79,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize """Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type""" info about the device type"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)
def __init__( def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
): ):
...@@ -91,7 +97,7 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): ...@@ -91,7 +97,7 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
def new_req_logits_processor( def new_req_logits_processor(
self, self,
params: SamplingParams, params: SamplingParams,
) -> Optional[RequestLogitsProcessor]: ) -> RequestLogitsProcessor | None:
"""This method returns a new request-level logits processor, customized """This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request. to the `target_token` value associated with a particular request.
...@@ -115,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): ...@@ -115,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
is None is None
): ):
return None return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token) return DummyPerReqLogitsProcessor(target_token)
......
...@@ -8,7 +8,6 @@ Requires HuggingFace credentials for access. ...@@ -8,7 +8,6 @@ Requires HuggingFace credentials for access.
""" """
import gc import gc
from typing import Optional
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
...@@ -19,7 +18,7 @@ from vllm.lora.request import LoRARequest ...@@ -19,7 +18,7 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str, lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
return [ return [
# this is an example of using quantization without LoRA # this is an example of using quantization without LoRA
( (
...@@ -56,7 +55,7 @@ def create_test_prompts( ...@@ -56,7 +55,7 @@ def create_test_prompts(
def process_requests( def process_requests(
engine: LLMEngine, engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
): ):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
...@@ -78,7 +77,7 @@ def process_requests( ...@@ -78,7 +77,7 @@ def process_requests(
def initialize_engine( def initialize_engine(
model: str, quantization: str, lora_repo: Optional[str] model: str, quantization: str, lora_repo: str | None
) -> LLMEngine: ) -> LLMEngine:
"""Initialize the LLMEngine.""" """Initialize the LLMEngine."""
......
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
This file demonstrates the usage of text generation with an LLM model, This file demonstrates the usage of text generation with an LLM model,
comparing the performance with and without speculative decoding. comparing the performance with and without speculative decoding.
Note that still not support `v1`: Note that this example is out of date and not supported in vLLM v1.
VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
""" """
import gc import gc
......
...@@ -7,8 +7,6 @@ for offline inference. ...@@ -7,8 +7,6 @@ for offline inference.
Requires HuggingFace credentials for access to Llama2. Requires HuggingFace credentials for access to Llama2.
""" """
from typing import Optional
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
...@@ -17,7 +15,7 @@ from vllm.lora.request import LoRARequest ...@@ -17,7 +15,7 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str, lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
"""Create a list of test prompts with their sampling parameters. """Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2 2 requests for base model, 4 requests for the LoRA. We define 2
...@@ -68,7 +66,7 @@ def create_test_prompts( ...@@ -68,7 +66,7 @@ def create_test_prompts(
def process_requests( def process_requests(
engine: LLMEngine, engine: LLMEngine,
test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
): ):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
......
...@@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ ...@@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
""" """
try: try:
url = s3_client.generate_presigned_url( url = s3_client.generate_presigned_url(
ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in ClientMethod=client_method,
Params=method_parameters,
ExpiresIn=expires_in,
) )
except ClientError: except ClientError:
raise raise
...@@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ ...@@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_
s3_client = boto3.client("s3") s3_client = boto3.client("s3")
input_url = generate_presigned_url( input_url = generate_presigned_url(
s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 s3_client,
"get_object",
{"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
expires_in=3600,
) )
output_url = generate_presigned_url( output_url = generate_presigned_url(
s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 s3_client,
"put_object",
{"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
expires_in=3600,
) )
print(f"{input_url=}") print(f"{input_url=}")
print(f"{output_url=}") print(f"{output_url=}")
......
...@@ -14,7 +14,7 @@ python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_na ...@@ -14,7 +14,7 @@ python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_na
## Embed jina_embeddings_v3 usage ## Embed jina_embeddings_v3 usage
Only text matching task is supported for now. See <gh-pr:16120> Only text matching task is supported for now. See <https://github.com/vllm-project/vllm/pull/16120>
```bash ```bash
python examples/offline_inference/pooling/embed_jina_embeddings_v3.py python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
...@@ -26,12 +26,30 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py ...@@ -26,12 +26,30 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py python examples/offline_inference/pooling/embed_matryoshka_fy.py
``` ```
## Multi vector retrieval usage
```bash
python examples/offline_inference/pooling/multi_vector_retrieval.py
```
## Named Entity Recognition (NER) usage ## Named Entity Recognition (NER) usage
```bash ```bash
python examples/offline_inference/pooling/ner.py python examples/offline_inference/pooling/ner.py
``` ```
## Prithvi Geospatial MAE usage
```bash
python examples/offline_inference/pooling/prithvi_geospatial_mae.py
```
## IO Processor Plugins for Prithvi Geospatial MAE
```bash
python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py
```
## Qwen3 reranker usage ## Qwen3 reranker usage
```bash ```bash
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args(): def parse_args():
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs, PoolingParams from vllm import LLM, EngineArgs, PoolingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args(): def parse_args():
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="BAAI/bge-m3",
runner="pooling",
enforce_eager=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass runner="pooling" for embedding models
llm = LLM(**vars(args))
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = llm.embed(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
print(len(embeds))
# Generate embedding for each token. The output is a list of PoolingRequestOutput.
outputs = llm.encode(prompts, pooling_task="token_embed")
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
multi_vector = output.outputs.data
print(multi_vector.shape)
if __name__ == "__main__":
args = parse_args()
main(args)
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args(): def parse_args():
...@@ -33,7 +33,7 @@ def main(args: Namespace): ...@@ -33,7 +33,7 @@ def main(args: Namespace):
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
# Run inference # Run inference
outputs = llm.encode(prompts) outputs = llm.encode(prompts, pooling_task="token_classify")
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
logits = output.outputs.data logits = output.outputs.data
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import argparse import argparse
import datetime import datetime
import os import os
from typing import Union
import albumentations import albumentations
import numpy as np import numpy as np
...@@ -50,6 +49,7 @@ class PrithviMAE: ...@@ -50,6 +49,7 @@ class PrithviMAE:
dtype="float16", dtype="float16",
enforce_eager=True, enforce_eager=True,
model_impl="terratorch", model_impl="terratorch",
enable_mm_embeds=True,
) )
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
...@@ -64,7 +64,7 @@ class PrithviMAE: ...@@ -64,7 +64,7 @@ class PrithviMAE:
} }
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False) outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)
return outputs[0].outputs.data return outputs[0].outputs.data
...@@ -160,7 +160,7 @@ def load_example( ...@@ -160,7 +160,7 @@ def load_example(
file_paths: list[str], file_paths: list[str],
mean: list[float] = None, mean: list[float] = None,
std: list[float] = None, std: list[float] = None,
indices: Union[list[int], None] = None, indices: list[int] | None = None,
): ):
"""Build an input example by loading images in *file_paths*. """Build an input example by loading images in *file_paths*.
......
...@@ -6,14 +6,14 @@ import os ...@@ -6,14 +6,14 @@ import os
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.pooling_params import PoolingParams
# This example shows how to perform an offline inference that generates # This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff # multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and # image as input, process it using the multimodal data processor, and
# perform inference. # perform inference.
# Requirement - install plugin at: # Requirements:
# https://github.com/christian-pinto/prithvi_io_processor_plugin # - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
def main(): def main():
...@@ -36,15 +36,12 @@ def main(): ...@@ -36,15 +36,12 @@ def main():
# to avoid the model going OOM. # to avoid the model going OOM.
# The maximum number depends on the available GPU memory # The maximum number depends on the available GPU memory
max_num_seqs=32, max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff", io_processor_plugin="terratorch_segmentation",
model_impl="terratorch", model_impl="terratorch",
enable_mm_embeds=True,
) )
pooling_params = PoolingParams(task="encode", softmax=False) pooler_output = llm.encode(img_prompt, pooling_task="plugin")
pooler_output = llm.encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs output = pooler_output[0].outputs
print(output) print(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment