Unverified Commit e254497b authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)

parent 4e121310
from vllm import LLM
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 4096 floats
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
responses = client.embeddings.create(input=[
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
],
model=model)
for data in responses.data:
print(data.embedding) # list of float of len 4096
...@@ -19,12 +19,15 @@ pytest-forked ...@@ -19,12 +19,15 @@ pytest-forked
pytest-asyncio pytest-asyncio
pytest-rerunfailures pytest-rerunfailures
pytest-shard pytest-shard
httpx
# testing utils
awscli
einops # required for MPT einops # required for MPT
httpx
peft
requests requests
ray ray
peft sentence-transformers # required for embedding
awscli
# Benchmarking # Benchmarking
aiohttp aiohttp
......
...@@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = { ...@@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
} }
_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
class HfRunner: class HfRunner:
...@@ -145,14 +149,7 @@ class HfRunner: ...@@ -145,14 +149,7 @@ class HfRunner:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name self.model_name = model_name
if model_name not in _VISION_LANGUAGE_MODELS: if model_name in _VISION_LANGUAGE_MODELS:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
else:
self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
...@@ -162,6 +159,20 @@ class HfRunner: ...@@ -162,6 +159,20 @@ class HfRunner:
model_name, model_name,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
) )
elif model_name in _EMBEDDING_MODELS:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device="cpu",
).to(dtype=torch_dtype).cuda()
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
if tokenizer_name is None: if tokenizer_name is None:
tokenizer_name = model_name tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
...@@ -334,6 +345,9 @@ class HfRunner: ...@@ -334,6 +345,9 @@ class HfRunner:
return [(output_ids, output_str, output_logprobs) return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs] for output_ids, output_str, output_logprobs in outputs]
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
return self.model.encode(prompts)
def __del__(self): def __del__(self):
del self.model del self.model
cleanup() cleanup()
...@@ -459,6 +473,14 @@ class VllmRunner: ...@@ -459,6 +473,14 @@ class VllmRunner:
outputs = self.generate(prompts, beam_search_params) outputs = self.generate(prompts, beam_search_params)
return outputs return outputs
def encode(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs
def __del__(self): def __del__(self):
del self.model del self.model
cleanup() cleanup()
......
...@@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler ...@@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceStatus) SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter from vllm.utils import Counter
...@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): ...@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
new_token_ids = list(range(num_new_tokens)) new_token_ids = list(range(num_new_tokens))
outputs = [ outputs = [
SequenceGroupOutput( CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput( SequenceOutput(
parent_seq_id=seq.seq_id, parent_seq_id=seq.seq_id,
...@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, ...@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
new_token_ids = list(range(num_new_tokens)) new_token_ids = list(range(num_new_tokens))
outputs = [ outputs = [
SequenceGroupOutput( CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput( SequenceOutput(
parent_seq_id=seq.seq_id, parent_seq_id=seq.seq_id,
...@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, ...@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id new_token_ids[eos_index] = eos_token_id
outputs = [ outputs = [
SequenceGroupOutput( CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput( SequenceOutput(
parent_seq_id=seq.seq_id, parent_seq_id=seq.seq_id,
...@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, ...@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids[eos_index] = eos_token_id new_token_ids[eos_index] = eos_token_id
outputs = [ outputs = [
SequenceGroupOutput( CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput( SequenceOutput(
parent_seq_id=seq.seq_id, parent_seq_id=seq.seq_id,
......
...@@ -14,6 +14,7 @@ class MockModelConfig: ...@@ -14,6 +14,7 @@ class MockModelConfig:
tokenizer_mode = "auto" tokenizer_mode = "auto"
max_model_len = 100 max_model_len = 100
tokenizer_revision = None tokenizer_revision = None
embedding_mode = False
@dataclass @dataclass
......
...@@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer ...@@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing # technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here # generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora" LORA_NAME = "typeof/zephyr-7b-beta-lora"
...@@ -121,7 +122,7 @@ def zephyr_lora_files(): ...@@ -121,7 +122,7 @@ def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME) return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="session") @pytest.fixture(scope="module")
def server(zephyr_lora_files): def server(zephyr_lora_files):
ray.init() ray.init()
server_runner = ServerRunner.remote([ server_runner = ServerRunner.remote([
...@@ -150,6 +151,25 @@ def server(zephyr_lora_files): ...@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
ray.shutdown() ray.shutdown()
@pytest.fixture(scope="module")
def embedding_server(zephyr_lora_files):
ray.shutdown()
ray.init()
server_runner = ServerRunner.remote([
"--model",
EMBEDDING_MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def client(): def client():
client = openai.AsyncOpenAI( client = openai.AsyncOpenAI(
...@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): ...@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message) or "less_than_equal" in exc_info.value.message)
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
model_name: str):
input = [
"The chef prepared a delicious meal.",
]
# test single embedding
embeddings = await client.embeddings.create(
model=model_name,
input=input,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 9
assert embeddings.usage.total_tokens == 9
# test using token IDs
input = [1, 1, 1, 1, 1]
embeddings = await client.embeddings.create(
model=model_name,
input=input,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 5
assert embeddings.usage.total_tokens == 5
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
model_name: str):
# test List[str]
inputs = [
"The cat sat on the mat.", "A feline was resting on a rug.",
"Stars twinkle brightly in the night sky."
]
embeddings = await client.embeddings.create(
model=model_name,
input=inputs,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) == 4096
# test List[List[int]]
inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]]
embeddings = await client.embeddings.create(
model=model_name,
input=inputs,
encoding_format="float",
)
assert embeddings.id is not None
assert embeddings.data is not None and len(embeddings.data) == 4
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 17
assert embeddings.usage.total_tokens == 17
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_llama_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F
MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
def compare_embeddings(embeddings1, embeddings2):
similarities = [
F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
for e1, e2 in zip(embeddings1, embeddings2)
]
return similarities
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.encode(example_prompts)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.encode(example_prompts)
del vllm_model
similarities = compare_embeddings(hf_outputs, vllm_outputs)
all_similarities = torch.stack(similarities)
tolerance = 1e-2
assert torch.all((all_similarities <= 1.0 + tolerance)
& (all_similarities >= 1.0 - tolerance)
), f"Not all values are within {tolerance} of 1.0"
...@@ -36,14 +36,14 @@ def test_logits_processor_force_generate( ...@@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None # test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[0], prompt=example_prompts[0],
sampling_params=params_with_logprobs, params=params_with_logprobs,
prompt_token_ids=None, prompt_token_ids=None,
) )
# test prompt_logprobs is not None # test prompt_logprobs is not None
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[1], prompt=example_prompts[1],
sampling_params=SamplingParams( params=SamplingParams(
prompt_logprobs=3, prompt_logprobs=3,
max_tokens=max_tokens, max_tokens=max_tokens,
), ),
...@@ -53,7 +53,7 @@ def test_logits_processor_force_generate( ...@@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
# test grouped requests # test grouped requests
vllm_model.model._add_request( vllm_model.model._add_request(
prompt=example_prompts[2], prompt=example_prompts[2],
sampling_params=SamplingParams(max_tokens=max_tokens), params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None, prompt_token_ids=None,
) )
......
...@@ -60,7 +60,7 @@ def test_random_sample_with_seed( ...@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
llm._add_request( llm._add_request(
prompt=prompt, prompt=prompt,
prompt_token_ids=None, prompt_token_ids=None,
sampling_params=params, params=params,
) )
results = llm._run_engine(use_tqdm=False) results = llm._run_engine(use_tqdm=False)
......
...@@ -7,8 +7,8 @@ import torch ...@@ -7,8 +7,8 @@ import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, SequenceData, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceGroupMetadata, SequenceGroupOutput, SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput) SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
...@@ -170,7 +170,7 @@ def create_sampler_output_list( ...@@ -170,7 +170,7 @@ def create_sampler_output_list(
return [ return [
SamplerOutput(outputs=[ SamplerOutput(outputs=[
SequenceGroupOutput( CompletionSequenceGroupOutput(
samples=[ samples=[
SequenceOutput( SequenceOutput(
output_token=token_id, output_token=token_id,
......
import pytest import pytest
from tests.core.utils import create_dummy_prompt from tests.core.utils import create_dummy_prompt
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
SequenceOutput) SequenceData, SequenceOutput)
@pytest.fixture @pytest.fixture
def sample_outputs(): def sample_outputs():
return [ return [
SequenceGroupOutput(samples=[ CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
], ],
prompt_logprobs=None) for i in range(5) prompt_logprobs=None) for i in range(5)
...@@ -32,7 +32,7 @@ def test_sampler_output_getitem(sampler_output, sample_outputs): ...@@ -32,7 +32,7 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):
def test_sampler_output_setitem(sampler_output): def test_sampler_output_setitem(sampler_output):
new_output = SequenceGroupOutput(samples=[ new_output = CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
], ],
prompt_logprobs=None) prompt_logprobs=None)
......
...@@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine ...@@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
__version__ = "0.4.2" __version__ = "0.4.2"
...@@ -17,9 +19,12 @@ __all__ = [ ...@@ -17,9 +19,12 @@ __all__ = [
"SamplingParams", "SamplingParams",
"RequestOutput", "RequestOutput",
"CompletionOutput", "CompletionOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"LLMEngine", "LLMEngine",
"EngineArgs", "EngineArgs",
"AsyncLLMEngine", "AsyncLLMEngine",
"AsyncEngineArgs", "AsyncEngineArgs",
"initialize_ray_cluster", "initialize_ray_cluster",
"PoolingParams",
] ]
...@@ -9,6 +9,7 @@ from transformers import PretrainedConfig ...@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config) get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
...@@ -22,6 +23,7 @@ if TYPE_CHECKING: ...@@ -22,6 +23,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_GB = 1 << 30 _GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
class ModelConfig: class ModelConfig:
...@@ -126,6 +128,7 @@ class ModelConfig: ...@@ -126,6 +128,7 @@ class ModelConfig:
served_model_name) served_model_name)
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
self._verify_embedding_mode()
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
...@@ -137,6 +140,11 @@ class ModelConfig: ...@@ -137,6 +140,11 @@ class ModelConfig:
"either 'auto' or 'slow'.") "either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"] rocm_supported_quantization = ["gptq", "squeezellm"]
...@@ -591,6 +599,7 @@ class SchedulerConfig: ...@@ -591,6 +599,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt. prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens. on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
""" """
def __init__( def __init__(
...@@ -602,6 +611,7 @@ class SchedulerConfig: ...@@ -602,6 +611,7 @@ class SchedulerConfig:
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
delay_factor: float = 0.0, delay_factor: float = 0.0,
enable_chunked_prefill: bool = False, enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
) -> None: ) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
...@@ -610,6 +620,10 @@ class SchedulerConfig: ...@@ -610,6 +620,10 @@ class SchedulerConfig:
# It is the values that have the best balance between ITL # It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput. # and TTFT on A100. Note it is not optimized for throughput.
self.max_num_batched_tokens = 512 self.max_num_batched_tokens = 512
elif embedding_mode:
# For embedding, choose specific value for higher throughput
self.max_num_batched_tokens = max(
max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
else: else:
# If max_model_len is too short, use 2048 as the default value # If max_model_len is too short, use 2048 as the default value
# for higher throughput. # for higher throughput.
...@@ -623,6 +637,7 @@ class SchedulerConfig: ...@@ -623,6 +637,7 @@ class SchedulerConfig:
self.num_lookahead_slots = num_lookahead_slots self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self._verify_args() self._verify_args()
......
from typing import List, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
"""An embedding version of BlockSpaceManager for use in environments
with embedding models where block management is not required.
This class provides the same interface as BlockSpaceManager, but its
methods perform no actions or return simple values like True in specific
actions. It's designed to be used in scenarios where the overhead of
block management is unnecessary, such as in an embedding environment.
"""
def __init__(
self,
**kwargs,
) -> None:
pass
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK
def allocate(self, seq_group: SequenceGroup) -> None:
# No actual allocation logic needed
pass
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return True
def append_slots(
self,
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
return None # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.OK
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> List[Tuple[int, int]]:
return None # type: ignore
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return True
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
return None # type: ignore
def free(self, seq: Sequence) -> None:
# No operation on free
return
def get_block_table(self, seq: Sequence) -> List[int]:
return None # type: ignore
def get_num_free_gpu_blocks(self) -> int:
return 1
def get_num_free_cpu_blocks(self) -> int:
return 1
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
...@@ -35,6 +35,11 @@ class BlockSpaceManager(ABC): ...@@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2 return BlockSpaceManagerV2
if version == "embedding":
from vllm.core.embedding_model_block_manager import (
EmbeddingModelBlockSpaceManager)
return EmbeddingModelBlockSpaceManager
raise ValueError(f"Unknown version {version=}") raise ValueError(f"Unknown version {version=}")
@abstractmethod @abstractmethod
......
...@@ -270,9 +270,14 @@ class Scheduler: ...@@ -270,9 +270,14 @@ class Scheduler:
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
if self.scheduler_config.embedding_mode:
version = "embedding"
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config. version)
use_v2_block_manager else "v1")
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManagerImpl( self.block_manager = BlockSpaceManagerImpl(
...@@ -968,6 +973,7 @@ class Scheduler: ...@@ -968,6 +973,7 @@ class Scheduler:
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
do_sample=do_sample, do_sample=do_sample,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size, token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request, lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums, computed_block_nums=common_computed_block_nums,
......
...@@ -574,6 +574,7 @@ class EngineArgs: ...@@ -574,6 +574,7 @@ class EngineArgs:
speculative_config.num_lookahead_slots), speculative_config.num_lookahead_slots),
delay_factor=self.scheduler_delay_factor, delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
) )
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
......
...@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine ...@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -47,15 +48,16 @@ def _raise_exception_on_finish( ...@@ -47,15 +48,16 @@ def _raise_exception_on_finish(
class AsyncStream: class AsyncStream:
"""A stream of RequestOutputs for a request that can be """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
iterated over asynchronously.""" that can be iterated over asynchronously."""
def __init__(self, request_id: str) -> None: def __init__(self, request_id: str) -> None:
self.request_id = request_id self.request_id = request_id
self._queue: asyncio.Queue = asyncio.Queue() self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None: def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished: if self._finished:
return return
self._queue.put_nowait(item) self._queue.put_nowait(item)
...@@ -71,7 +73,7 @@ class AsyncStream: ...@@ -71,7 +73,7 @@ class AsyncStream:
def __aiter__(self): def __aiter__(self):
return self return self
async def __anext__(self) -> RequestOutput: async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get() result = await self._queue.get()
if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
...@@ -108,7 +110,8 @@ class RequestTracker: ...@@ -108,7 +110,8 @@ class RequestTracker:
self.abort_request(rid) self.abort_request(rid)
def process_request_output(self, def process_request_output(self,
request_output: RequestOutput, request_output: Union[RequestOutput,
EmbeddingRequestOutput],
*, *,
verbose: bool = False) -> None: verbose: bool = False) -> None:
"""Process a request output from the engine.""" """Process a request output from the engine."""
...@@ -196,7 +199,8 @@ class RequestTracker: ...@@ -196,7 +199,8 @@ class RequestTracker:
class _AsyncLLMEngine(LLMEngine): class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]: async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
...@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
return self.add_request(request_id, return self.add_request(request_id,
prompt=prompt, prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
...@@ -511,7 +515,7 @@ class AsyncLLMEngine: ...@@ -511,7 +515,7 @@ class AsyncLLMEngine:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -528,9 +532,9 @@ class AsyncLLMEngine: ...@@ -528,9 +532,9 @@ class AsyncLLMEngine:
max_log_len] max_log_len]
logger.info( logger.info(
"Received request %s: prompt: %r, " "Received request %s: prompt: %r, "
"sampling_params: %s, prompt_token_ids: %s, " "params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, "lora_request: %s.", request_id, shortened_prompt, params,
sampling_params, shortened_token_ids, lora_request) shortened_token_ids, lora_request)
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
...@@ -562,7 +566,7 @@ class AsyncLLMEngine: ...@@ -562,7 +566,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params, params=params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
...@@ -597,8 +601,8 @@ class AsyncLLMEngine: ...@@ -597,8 +601,8 @@ class AsyncLLMEngine:
multi_modal_data: Multi modal data per request. multi_modal_data: Multi modal data per request.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine for the The output `RequestOutput` objects from the LLMEngine
request. for the request.
Details: Details:
- If the engine is not running, start the background loop, - If the engine is not running, start the background loop,
...@@ -643,25 +647,123 @@ class AsyncLLMEngine: ...@@ -643,25 +647,123 @@ class AsyncLLMEngine:
>>> # Process and return the final output >>> # Process and return the final output
>>> ... >>> ...
""" """
# Preprocess the request. async for output in self.process_request(
request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
):
yield output
async def encode(
self,
prompt: Optional[str],
pooling_params: PoolingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
request_id,
prompt,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
):
yield output
async def process_request(
self,
request_id: str,
prompt: Optional[str],
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time = time.time() arrival_time = time.time()
try:
stream = await self.add_request( stream = await self.add_request(
request_id, request_id,
prompt, prompt,
sampling_params, params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
try:
async for request_output in stream: async for request_output in stream:
yield request_output yield request_output
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id) self._abort(request_id)
raise e raise e
......
...@@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase ...@@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
...@@ -169,6 +172,7 @@ class LLMEngine: ...@@ -169,6 +172,7 @@ class LLMEngine:
load_config=load_config, load_config=load_config,
) )
if not self.model_config.embedding_mode:
self._initialize_kv_caches() self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
...@@ -354,7 +358,7 @@ class LLMEngine: ...@@ -354,7 +358,7 @@ class LLMEngine:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -370,7 +374,8 @@ class LLMEngine: ...@@ -370,7 +374,8 @@ class LLMEngine:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is prompt: The prompt string. Can be None if prompt_token_ids is
provided. provided.
sampling_params: The sampling parameters for text generation. params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
...@@ -404,13 +409,6 @@ class LLMEngine: ...@@ -404,13 +409,6 @@ class LLMEngine:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
prompt_token_ids = self.encode_request( prompt_token_ids = self.encode_request(
...@@ -432,6 +430,50 @@ class LLMEngine: ...@@ -432,6 +430,50 @@ class LLMEngine:
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request) eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
...@@ -443,11 +485,35 @@ class LLMEngine: ...@@ -443,11 +485,35 @@ class LLMEngine:
self.generation_config_fields) self.generation_config_fields)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id=request_id,
arrival_time, lora_request, multi_modal_data) seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
# Add the sequence group to the scheduler. return seq_group
self.scheduler.add_seq_group(seq_group)
def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID. """Aborts a request(s) with the given ID.
...@@ -484,13 +550,25 @@ class LLMEngine: ...@@ -484,13 +550,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return self.scheduler.has_unfinished_seqs()
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _process_model_outputs( def _process_model_outputs(
self, self,
output: List[SamplerOutput], output: List[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup], scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups. """Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client. Returns RequestOutputs that can be returned to the client.
...@@ -510,6 +588,9 @@ class LLMEngine: ...@@ -510,6 +588,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
if self.model_config.embedding_mode:
self._process_sequence_group_outputs(seq_group, outputs)
continue
self.output_processor.process_prompt_logprob(seq_group, outputs) self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample: if seq_group_meta.do_sample:
...@@ -519,18 +600,19 @@ class LLMEngine: ...@@ -519,18 +600,19 @@ class LLMEngine:
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
for scheduled_seq_group in scheduled_seq_groups: for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
for seq_group in ignored_seq_groups: for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png .. figure:: https://i.imgur.com/sv2HssD.png
...@@ -570,7 +652,7 @@ class LLMEngine: ...@@ -570,7 +652,7 @@ class LLMEngine:
>>> while True: >>> while True:
>>> if example_inputs: >>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0) >>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params) >>> engine.add_request(str(req_id),prompt,sampling_params)
>>> >>>
>>> # continue the request processing >>> # continue the request processing
>>> request_outputs = engine.step() >>> request_outputs = engine.step()
...@@ -637,12 +719,15 @@ class LLMEngine: ...@@ -637,12 +719,15 @@ class LLMEngine:
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() gpu_cache_usage_sys = 0.
if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu > 0: if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
) )
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
...@@ -716,7 +801,9 @@ class LLMEngine: ...@@ -716,7 +801,9 @@ class LLMEngine:
seq.get_output_len() seq.get_output_len()
for seq in seq_group.get_finished_seqs() for seq in seq_group.get_finished_seqs()
]) ])
best_of_requests.append(seq_group.sampling_params.best_of) if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n) n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([ finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status) SequenceStatus.get_finished_reason(seq.status)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment