"docs/vscode:/vscode.git/clone" did not exist on "cc48dd3e35dd8aee39a6f1fcd0ed08f1b56fc2bc"
Unverified Commit 020f58ab authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Core] Support multiple tasks per model (#20771)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent c1acd6d7
...@@ -54,7 +54,7 @@ def test_get_field(): ...@@ -54,7 +54,7 @@ def test_get_field():
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"), ("openai/whisper-small", "generate", "transcription"),
], ],
) )
def test_auto_task(model_id, expected_runner_type, expected_task): def test_auto_task(model_id, expected_runner_type, expected_task):
...@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task): ...@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
) )
assert config.runner_type == expected_runner_type assert config.runner_type == expected_runner_type
assert config.task == expected_task
if config.runner_type == "pooling":
assert config.task == expected_task
else:
assert expected_task in config.supported_tasks
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task): ...@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
[
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
])
def test_draft_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
runner="draft",
tokenizer=model_id,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("openai/whisper-small", "generate", "transcription"),
],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="transcription",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [ @pytest.mark.parametrize(("model_id", "bad_task"), [
("Qwen/Qwen2.5-Math-RM-72B", "generate"), ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
("Qwen/Qwen3-0.6B", "transcription"),
]) ])
def test_incorrect_task(model_id, bad_task): def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"): with pytest.raises(ValueError, match=r"does not support task=.*"):
ModelConfig( ModelConfig(
model_id, model_id,
task=bad_task, task=bad_task,
......
...@@ -91,24 +91,19 @@ logger = init_logger(__name__) ...@@ -91,24 +91,19 @@ logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType) ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"] "score", "reward", "transcription", "draft"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft", _ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
"transcription"] "classify", "reward", "draft"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"] RunnerOption = Literal["auto", "generate", "pooling", "draft"]
_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { RunnerType = Literal["generate", "pooling", "draft"]
"generate": ["generate"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
task: runner "generate": ["generate", "transcription"],
for runner, tasks in _RUNNER_TASKS.items() "pooling": ["pooling", "embed", "classify", "reward"],
for task in tasks "draft": [],
} }
...@@ -234,11 +229,14 @@ class ModelConfig: ...@@ -234,11 +229,14 @@ class ModelConfig:
"""Name or path of the Hugging Face model to use. It is also used as the """Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is content for `model_name` tag in metrics output when `served_model_name` is
not specified.""" not specified."""
task: Literal[TaskOption, Literal["draft"]] = "auto" runner: RunnerOption = "auto"
"""The task to use the model for. Each vLLM instance only supports one """The type of model runner to use. Each vLLM instance only supports one
task, even if the same model can be used for multiple tasks. When the model model runner, even if the same model can be used for multiple types."""
only supports one task, "auto" can be used to select it; otherwise, you task: TaskOption = "auto"
must specify explicitly which task to use.""" """The task to use the model for. If the model supports more than one
model runner, this is used to select which model runner to run.
Note that the model may support other tasks using the same model runner."""
tokenizer: SkipValidation[str] = None # type: ignore tokenizer: SkipValidation[str] = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model """Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used.""" name or path will be used."""
...@@ -553,10 +551,41 @@ class ModelConfig: ...@@ -553,10 +551,41 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision) self.model, hf_token=self.hf_token, revision=self.revision)
supported_tasks, task = self._resolve_task(self.task) # For pooling models, self.task is used to indicate the
self.supported_tasks = supported_tasks # user-selected task
self.task = task if self.task == "score":
if self.task in ("draft", "generate"): if self.registry.is_cross_encoder_model(self.architectures):
self.task = "classify"
else:
self.task = "embed"
elif self.task == "embedding":
msg = ("The 'embedding' task has been renamed to 'embed', please "
"use the new name. The old name will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.task = "embed"
all_supported_tasks = self._get_supported_tasks(self.task)
logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
supported_runner_types = self._get_supported_runner_types(
all_supported_tasks)
runner_type = self._resolve_runner(self.runner, self.task,
supported_runner_types,
all_supported_tasks)
logger.debug("Selected runner type: %s", runner_type)
# For pooling models, self.task is used to indicate the
# user-selected task
if runner_type == "pooling" and self.task == "auto":
selected_task = all_supported_tasks[runner_type][-1]
assert selected_task != "pooling"
self.task = selected_task
self.supported_runner_types = supported_runner_types
self.runner_type = runner_type
self.supported_tasks = all_supported_tasks[runner_type]
if self.runner_type in ("draft",
"generate") and self.task != "transcription":
self.truncation_side = "left" self.truncation_side = "left"
else: else:
self.truncation_side = "right" self.truncation_side = "right"
...@@ -780,11 +809,10 @@ class ModelConfig: ...@@ -780,11 +809,10 @@ class ModelConfig:
f"one of {get_args(TokenizerMode)}.") f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _get_preferred_task( def _get_preferred_pooling_task(
self, self,
architectures: list[str], architectures: list[str],
supported_tasks: set[_ResolvedTask], ) -> _ResolvedTask:
) -> Optional[_ResolvedTask]:
model_id = self.model model_id = self.model
if get_pooling_config(model_id, self.revision): if get_pooling_config(model_id, self.revision):
return "embed" return "embed"
...@@ -795,92 +823,136 @@ class ModelConfig: ...@@ -795,92 +823,136 @@ class ModelConfig:
suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
# Other models follow this pattern # Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ForSequenceClassification", "classify"), ("ForSequenceClassification", "classify"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embed"), ("EmbeddingModel", "embed"),
("RewardModel", "reward"), ("RewardModel", "reward"),
] ]
_, arch = self.registry.inspect_model_cls(architectures) _, arch = self.registry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task: for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks: if arch.endswith(suffix):
return pref_task return pref_task
return None return "embed"
def _resolve_task( def _get_supported_generation_tasks(
self, self,
task_option: Literal[TaskOption, Literal["draft"]], task_option: TaskOption,
) -> tuple[set[_ResolvedTask], _ResolvedTask]: ) -> list[_ResolvedTask]:
if task_option == "draft": registry = self.registry
return {"draft"}, "draft" architectures = self.architectures
if registry.is_transcription_only_model(architectures):
return ["transcription"]
supported_tasks = list[_ResolvedTask]()
if registry.is_text_generation_model(architectures):
supported_tasks.append("generate")
if registry.is_transcription_model(architectures):
supported_tasks.append("transcription")
return supported_tasks
def _get_supported_pooling_tasks(
self,
task_option: TaskOption,
) -> list[_ResolvedTask]:
registry = self.registry registry = self.registry
architectures = self.architectures architectures = self.architectures
runner_support: dict[RunnerType, bool] = { supported_tasks = list[_ResolvedTask]()
# NOTE: Listed from highest to lowest priority, if registry.is_pooling_model(architectures):
# in case the model supports multiple of them supported_tasks.append("pooling")
"transcription": registry.is_transcription_model(architectures),
"generate": registry.is_text_generation_model(architectures), # For now, users must specify the task (other than "pooling")
"pooling": registry.is_pooling_model(architectures), # to use for pooling models
if task_option == "auto":
preferred_task = self._get_preferred_pooling_task(
architectures)
supported_tasks.append(preferred_task)
elif task_option in _RUNNER_TASKS["pooling"]:
supported_tasks.append(cast(_ResolvedTask, task_option))
return supported_tasks
def _get_supported_tasks(
self,
task_option: TaskOption,
) -> dict[RunnerType, list[_ResolvedTask]]:
return {
"generate": self._get_supported_generation_tasks(task_option),
"pooling": self._get_supported_pooling_tasks(task_option),
"draft": ["draft"]
} }
supported_runner_types_lst: list[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks_lst: list[_ResolvedTask] = [ def _get_supported_runner_types(
task for runner_type in supported_runner_types_lst self,
for task in _RUNNER_TASKS[runner_type] supported_tasks: dict[RunnerType, list[_ResolvedTask]],
] ) -> set[RunnerType]:
supported_tasks = set(supported_tasks_lst) return {
runner
for runner, runner_tasks in supported_tasks.items()
if len(runner_tasks) > 0
}
if task_option == "auto": def _resolve_runner(
selected_task = next(iter(supported_tasks_lst)) self,
runner_option: RunnerOption,
task_option: TaskOption,
supported_runner_types: set[RunnerType],
supported_tasks: dict[RunnerType, list[_ResolvedTask]],
) -> RunnerType:
if not supported_runner_types:
raise ValueError("This model does not support any model runners!")
if runner_option != "auto":
if runner_option not in supported_runner_types:
raise ValueError(
f"This model does not support runner={runner_option!r}. "
f"Available runners: {supported_runner_types}")
if len(supported_tasks_lst) > 1: return runner_option
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task is not None:
selected_task = preferred_task
logger.info( if task_option != "auto":
"This model supports multiple tasks: %s. " for runner, runner_tasks in supported_tasks.items():
"Defaulting to '%s'.", supported_tasks, selected_task) if task_option in runner_tasks:
else: return runner
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
else: else:
# Aliases task_runner: RunnerType = next(
if task_option == "embedding": runner for runner, tasks in _RUNNER_TASKS.items()
msg = ("The 'embedding' task has been renamed to " if task_option in tasks)
"'embed', please use the new name. The old name " raise ValueError(
"will be removed in v1.0.") f"This model does not support task={task_option!r}. "
warnings.warn(msg, DeprecationWarning, stacklevel=2) f"Available tasks for runner={task_runner!r}: "
f"{supported_tasks[task_runner]}")
task_option = "embed" suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("ForSequenceClassification", "pooling"),
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
_, arch = self.registry.inspect_model_cls(self.architectures)
if task_option not in supported_tasks: for suffix, pref_runner in suffix_to_preferred_runner:
msg = ( if arch.endswith(suffix) and pref_runner in supported_runner_types:
f"This model does not support the '{task_option}' task. " return pref_runner
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
selected_task = task_option if "classify" in supported_tasks.get("pooling", []):
# When multiple pooling tasks are present, default to
# pooling (eg cross-encoder) for non-standard architectures.
return "pooling"
if "generate" in supported_runner_types:
return "generate"
if "pooling" in supported_runner_types:
return "pooling"
return supported_tasks, selected_task raise AssertionError("This line should not be reached")
def _parse_quant_hf_config(self): def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None) quant_cfg = getattr(self.hf_config, "quantization_config", None)
...@@ -1449,14 +1521,6 @@ class ModelConfig: ...@@ -1449,14 +1521,6 @@ class ModelConfig:
def use_mla(self) -> bool: def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property
def supported_runner_types(self) -> set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
@property @property
def is_v1_compatible(self) -> bool: def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", []) architectures = getattr(self.hf_config, "architectures", [])
...@@ -2694,7 +2758,7 @@ class SpeculativeConfig: ...@@ -2694,7 +2758,7 @@ class SpeculativeConfig:
if self.model is not None: if self.model is not None:
self.draft_model_config = ModelConfig( self.draft_model_config = ModelConfig(
model=self.model, model=self.model,
task="draft", runner="draft",
tokenizer=self.target_model_config.tokenizer, tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode, tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config. trust_remote_code=self.target_model_config.
......
...@@ -454,20 +454,19 @@ class LLM: ...@@ -454,20 +454,19 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
""" """
runner_type = self.llm_engine.model_config.runner_type model_config = self.llm_engine.model_config
if runner_type not in ["generate", "transcription"]: runner_type = model_config.runner_type
if runner_type != "generate":
messages = [ messages = [
"LLM.generate() is only supported for (conditional) generation " "LLM.generate() is only supported for generative models."
"models (XForCausalLM, XForConditionalGeneration).",
] ]
supported_runner_types = self.llm_engine.model_config \ if "generate" in model_config.supported_runner_types:
.supported_runner_types
if "generate" in supported_runner_types:
messages.append( messages.append(
"Your model supports the 'generate' runner, but is " "Your model supports the 'generate' runner, but is "
f"currently initialized for the '{runner_type}' runner. " f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task generate`.") "Please initialize vLLM using `--task generate` or "
"`--task transcription`.")
raise ValueError(" ".join(messages)) raise ValueError(" ".join(messages))
...@@ -1091,13 +1090,12 @@ class LLM: ...@@ -1091,13 +1090,12 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
""" """
runner_type = self.llm_engine.model_config.runner_type model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
messages = ["LLM.encode() is only supported for pooling models."] messages = ["LLM.encode() is only supported for pooling models."]
supported_runner_types = self.llm_engine.model_config \ if "pooling" in model_config.supported_runner_types:
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append( messages.append(
"Your model supports the 'pooling' runner, but is " "Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. " f"currently initialized for the '{runner_type}' runner. "
...@@ -1119,13 +1117,13 @@ class LLM: ...@@ -1119,13 +1117,13 @@ class LLM:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams): elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(self.llm_engine.model_config) pooling_params.verify(model_config)
else: else:
for pooling_param in pooling_params: for pooling_param in pooling_params:
pooling_param.verify(self.llm_engine.model_config) pooling_param.verify(model_config)
tokenization_kwargs: dict[str, Any] = {} tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(self.llm_engine.model_config.max_model_len, _validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs) truncate_prompt_tokens, tokenization_kwargs)
self._validate_and_add_requests( self._validate_and_add_requests(
...@@ -1178,9 +1176,10 @@ class LLM: ...@@ -1178,9 +1176,10 @@ class LLM:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if self.llm_engine.model_config.task != "embed": model_config = self.llm_engine.model_config
raise ValueError( if "embed" not in model_config.supported_tasks:
"Embedding API is only enabled for `--task embed`") raise ValueError("Embedding API is not supported by this model. "
"Please set `--task embed`.")
items = self.encode(prompts, items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
...@@ -1223,9 +1222,11 @@ class LLM: ...@@ -1223,9 +1222,11 @@ class LLM:
A list of `ClassificationRequestOutput` objects containing the A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts. embedding vectors in the same order as the input prompts.
""" """
if self.llm_engine.model_config.task != "classify": model_config = self.llm_engine.model_config
if "classify" not in model_config.supported_tasks:
raise ValueError( raise ValueError(
"Classification API is only enabled for `--task classify`") "Classification API is not supported by this model. "
"Please set `--task classify`.")
items = self.encode(prompts, items = self.encode(prompts,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
...@@ -1392,13 +1393,12 @@ class LLM: ...@@ -1392,13 +1393,12 @@ class LLM:
A list of `ScoringRequestOutput` objects containing the A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts. generated scores in the same order as the input prompts.
""" """
runner_type = self.llm_engine.model_config.runner_type model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
messages = ["LLM.score() is only supported for pooling models."] messages = ["LLM.score() is only supported for pooling models."]
supported_runner_types = self.llm_engine.model_config \ if "pooling" in model_config.supported_runner_types:
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append( messages.append(
"Your model supports the 'pooling' runner, but is " "Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. " f"currently initialized for the '{runner_type}' runner. "
...@@ -1407,12 +1407,13 @@ class LLM: ...@@ -1407,12 +1407,13 @@ class LLM:
raise ValueError(" ".join(messages)) raise ValueError(" ".join(messages))
if self.llm_engine.model_config.task not in ("embed", "classify"): if all(t not in model_config.supported_tasks
raise ValueError("Score API is only enabled for " for t in ("embed", "classify")):
"`--task embed or --task classify`.") raise ValueError("Score API is not supported by this model. "
"Please set `--task embed` or `--task classify`.")
if (self.llm_engine.model_config.task == "classify" if (model_config.task == "classify"
and self.llm_engine.model_config.hf_config.num_labels != 1): and getattr(model_config.hf_config, "num_labels", 0) != 1):
raise ValueError("Score API is only enabled for num_labels == 1.") raise ValueError("Score API is only enabled for num_labels == 1.")
# the tokenizer for models such as # the tokenizer for models such as
......
...@@ -1520,7 +1520,7 @@ async def init_app_state( ...@@ -1520,7 +1520,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser, reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None ) if "generate" in model_config.supported_tasks else None
state.openai_serving_chat = OpenAIServingChat( state.openai_serving_chat = OpenAIServingChat(
engine_client, engine_client,
model_config, model_config,
...@@ -1537,7 +1537,7 @@ async def init_app_state( ...@@ -1537,7 +1537,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser, reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None ) if "generate" in model_config.supported_tasks else None
state.openai_serving_completion = OpenAIServingCompletion( state.openai_serving_completion = OpenAIServingCompletion(
engine_client, engine_client,
model_config, model_config,
...@@ -1545,7 +1545,7 @@ async def init_app_state( ...@@ -1545,7 +1545,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None ) if "generate" in model_config.supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling( state.openai_serving_pooling = OpenAIServingPooling(
engine_client, engine_client,
model_config, model_config,
...@@ -1553,7 +1553,7 @@ async def init_app_state( ...@@ -1553,7 +1553,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None ) if "pooling" in model_config.supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding( state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client, engine_client,
model_config, model_config,
...@@ -1561,22 +1561,24 @@ async def init_app_state( ...@@ -1561,22 +1561,24 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None ) if "embed" in model_config.supported_tasks else None
state.openai_serving_classification = ServingClassification( state.openai_serving_classification = ServingClassification(
engine_client, engine_client,
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
) if model_config.task == "classify" else None ) if "classify" in model_config.supported_tasks else None
enable_serving_reranking = (model_config.task == "classify" and getattr( enable_serving_reranking = ("classify" in model_config.supported_tasks
model_config.hf_config, "num_labels", 0) == 1) and getattr(model_config.hf_config,
"num_labels", 0) == 1)
state.openai_serving_scores = ServingScores( state.openai_serving_scores = ServingScores(
engine_client, engine_client,
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger) if ( request_logger=request_logger,
model_config.task == "embed" or enable_serving_reranking) else None ) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
...@@ -1591,13 +1593,13 @@ async def init_app_state( ...@@ -1591,13 +1593,13 @@ async def init_app_state(
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
) if model_config.runner_type == "transcription" else None ) if "transcription" in model_config.supported_tasks else None
state.openai_serving_translation = OpenAIServingTranslation( state.openai_serving_translation = OpenAIServingTranslation(
engine_client, engine_client,
model_config, model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
) if model_config.runner_type == "transcription" else None ) if "transcription" in model_config.supported_tasks else None
state.task = model_config.task state.task = model_config.task
state.enable_server_load_tracking = args.enable_server_load_tracking state.enable_server_load_tracking = args.enable_server_load_tracking
......
...@@ -348,7 +348,7 @@ async def main(args): ...@@ -348,7 +348,7 @@ async def main(args):
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None ) if "generate" in model_config.supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding( openai_serving_embedding = OpenAIServingEmbedding(
engine, engine,
model_config, model_config,
...@@ -356,17 +356,19 @@ async def main(args): ...@@ -356,17 +356,19 @@ async def main(args):
request_logger=request_logger, request_logger=request_logger,
chat_template=None, chat_template=None,
chat_template_content_format="auto", chat_template_content_format="auto",
) if model_config.task == "embed" else None ) if "embed" in model_config.supported_tasks else None
enable_serving_reranking = (model_config.task == "classify" and getattr( enable_serving_reranking = ("classify" in model_config.supported_tasks
model_config.hf_config, "num_labels", 0) == 1) and getattr(model_config.hf_config,
"num_labels", 0) == 1)
openai_serving_scores = (ServingScores( openai_serving_scores = ServingScores(
engine, engine,
model_config, model_config,
openai_serving_models, openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
) if (model_config.task == "embed" or enable_serving_reranking) else None) ) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
tracker = BatchProgressTracker() tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file) logger.info("Reading batch from %s...", args.input_file)
......
...@@ -694,6 +694,12 @@ class SupportsTranscription(Protocol): ...@@ -694,6 +694,12 @@ class SupportsTranscription(Protocol):
supports_transcription: ClassVar[Literal[True]] = True supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
"""
Transcription models can opt out of text generation by setting this to
`True`.
"""
@classmethod @classmethod
def get_generation_prompt(cls, audio: np.ndarray, def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str, stt_config: SpeechToTextConfig, language: str,
......
...@@ -284,6 +284,7 @@ class _ModelInfo: ...@@ -284,6 +284,7 @@ class _ModelInfo:
is_hybrid: bool is_hybrid: bool
has_noops: bool has_noops: bool
supports_transcription: bool supports_transcription: bool
supports_transcription_only: bool
supports_v0_only: bool supports_v0_only: bool
@staticmethod @staticmethod
...@@ -299,6 +300,8 @@ class _ModelInfo: ...@@ -299,6 +300,8 @@ class _ModelInfo:
is_attention_free=is_attention_free(model), is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model), is_hybrid=is_hybrid(model),
supports_transcription=supports_transcription(model), supports_transcription=supports_transcription(model),
supports_transcription_only=(supports_transcription(model) and
model.supports_transcription_only),
supports_v0_only=supports_v0_only(model), supports_v0_only=supports_v0_only(model),
has_noops=has_noops(model), has_noops=has_noops(model),
) )
...@@ -573,6 +576,13 @@ class _ModelRegistry: ...@@ -573,6 +576,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription return model_cls.supports_transcription
def is_transcription_only_model(
self,
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription_only
def is_v1_compatible( def is_v1_compatible(
self, self,
architectures: Union[str, list[str]], architectures: Union[str, list[str]],
......
...@@ -772,6 +772,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -772,6 +772,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
".fc2.": ".mlp.fc2." ".fc2.": ".mlp.fc2."
}) })
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
@classmethod @classmethod
def validate_language(cls, language: str) -> bool: def validate_language(cls, language: str) -> bool:
if language in ISO639_1_SUPPORTED_LANGS: if language in ISO639_1_SUPPORTED_LANGS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment