Unverified Commit 8f10d5e3 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Split up pooling tasks (#10820)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 40766ca1
......@@ -381,19 +381,20 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
if task != "generate":
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "generate" in supported_runner_types:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
"Your model supports the 'generate' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task generate`.")
raise ValueError(" ".join(messages))
......@@ -793,16 +794,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.encode() is only supported for pooling models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task embed`, "
"`--task classify`, `--task score` etc.")
raise ValueError(" ".join(messages))
......@@ -864,21 +867,23 @@ class LLM:
A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.score() is only supported for embedding models."]
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.score() is only supported for pooling models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task embed`, "
"`--task classify`, `--task score` etc.")
raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support the cross encoding")
raise ValueError("Your model does not support cross encoding")
tokenizer = self.llm_engine.get_tokenizer()
......
......@@ -573,7 +573,7 @@ def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
......@@ -582,7 +582,7 @@ def init_app_state(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
......@@ -590,13 +590,13 @@ def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embedding" else None
) if model_config.runner_type == "pooling" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if (model_config.task == "embedding" \
) if (model_config.runner_type == "pooling" \
and model_config.is_cross_encoder) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
......
......@@ -224,7 +224,7 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
......@@ -232,7 +232,7 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embedding" else None
) if model_config.runner_type == "pooling" else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
......
......@@ -35,7 +35,7 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"]
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
if model_config.runner_type == "pooling":
model_cls = as_embedding_model(model_cls)
return model_cls, arch
......
......@@ -42,7 +42,7 @@ class EngineCore:
executor_class: Type[Executor],
usage_context: UsageContext,
):
assert vllm_config.model_config.task != "embedding"
assert vllm_config.model_config.runner_type != "pooling"
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
......
......@@ -163,7 +163,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding":
if self.model_config.runner_type == "pooling":
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
......
......@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.task == "embedding":
if model_config.runner_type == "pooling":
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment