Unverified Commit 051eaf6d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Add user-configurable task for models that support both generation and embedding (#9424)

parent 7dbe738d
......@@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
......@@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
......@@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
......@@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
init_num_crops, inference_num_crops)
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=init_kwargs,
limit_mm_per_prompt={"image": 1})
......@@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
......
......@@ -57,7 +57,8 @@ def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
try:
model_config = ModelConfig(model_path,
model_path,
task="auto",
tokenizer=model_path,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
......
......@@ -2,6 +2,42 @@ import pytest
from vllm.config import ModelConfig
@pytest.mark.parametrize(("model_id", "expected_task"), [
("facebook/opt-125m", "generate"),
("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"),
("intfloat/e5-mistral-7b-instruct", "generate"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
ModelConfig(
model_id,
task=bad_task,
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
MODEL_IDS_EXPECTED = [
("Qwen/Qwen1.5-7B", 32768),
("mistralai/Mistral-7B-v0.1", 4096),
......@@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
model_id, expected = model_id_expected
model_config = ModelConfig(
model_id,
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
......@@ -32,7 +69,8 @@ def test_get_sliding_window():
# when use_sliding_window is False.
qwen2_model_config = ModelConfig(
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B",
task="auto",
tokenizer="Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
......@@ -49,7 +87,8 @@ def test_get_sliding_window():
mistral_model_config = ModelConfig(
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-v0.1",
task="auto",
tokenizer="mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
......@@ -70,7 +109,8 @@ def test_rope_customization():
llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
task="auto",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
......@@ -82,7 +122,8 @@ def test_rope_customization():
llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
task="auto",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
......@@ -98,7 +139,8 @@ def test_rope_customization():
longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
task="auto",
tokenizer="lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
......@@ -112,7 +154,8 @@ def test_rope_customization():
longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
task="auto",
tokenizer="lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
......
......@@ -59,7 +59,7 @@ def test_deprecate_kwargs_always():
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
......@@ -69,10 +69,10 @@ def test_deprecate_kwargs_never():
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
......@@ -86,15 +86,15 @@ def test_deprecate_kwargs_dynamic():
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
......
......@@ -8,7 +8,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
import openai
import pytest
......@@ -454,13 +454,13 @@ def multi_process_parallel(
@contextmanager
def error_on_warning():
def error_on_warning(category: Type[Warning] = Warning):
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
of the given category is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.filterwarnings("error", category=category)
yield
......
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
Optional, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
import torch
from transformers import PretrainedConfig
......@@ -33,6 +33,9 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
Task = Literal["generate", "embedding"]
TaskOption = Literal["auto", Task]
class ModelConfig:
"""Configuration for the model.
......@@ -41,6 +44,10 @@ class ModelConfig:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
......@@ -108,6 +115,7 @@ class ModelConfig:
def __init__(self,
model: str,
task: TaskOption,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
......@@ -207,7 +215,11 @@ class ModelConfig:
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
......@@ -241,18 +253,41 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
def _resolve_task(
self,
task_option: TaskOption,
hf_config: PretrainedConfig,
) -> Tuple[Set[Task], Task]:
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks_lst: List[Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)
# TODO: Allow the same model architecture to be specified as either
# generation or embedding model
if "Phi3VForCausalLM" in architectures:
# Match both remote and local names
embedding_mode = "/VLM2Vec" in self.model
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
embedding_mode = ModelRegistry.is_embedding_model(architectures)
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
self.embedding_mode = embedding_mode
selected_task = task_option
return supported_tasks, selected_task
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
......@@ -401,7 +436,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
if self.task == "embedding":
self.use_async_output_proc = False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
......@@ -582,11 +617,6 @@ class ModelConfig:
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
......@@ -943,6 +973,7 @@ class SchedulerConfig:
"""Scheduler configuration.
Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
......@@ -957,7 +988,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
......@@ -972,13 +1002,13 @@ class SchedulerConfig:
"""
def __init__(self,
task: Task,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: bool = False,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
......@@ -1002,7 +1032,7 @@ class SchedulerConfig:
# for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048)
if embedding_mode:
if task == "embedding":
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
......@@ -1022,12 +1052,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens)
self.task: Final = task
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
......@@ -1239,6 +1269,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task=target_model_config.task,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
......
......@@ -313,7 +313,7 @@ class Scheduler:
self.lora_config = lora_config
version = "selfattn"
if (self.scheduler_config.embedding_mode
if (self.scheduler_config.task == "embedding"
or self.cache_config.is_attention_free):
version = "placeholder"
......
......@@ -3,7 +3,7 @@ import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast)
Tuple, Type, Union, cast, get_args)
import torch
......@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
......@@ -84,6 +84,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
......@@ -198,6 +199,15 @@ class EngineArgs:
type=str,
default=EngineArgs.model,
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--task',
default=EngineArgs.task,
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
......@@ -838,6 +848,7 @@ class EngineArgs:
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
......@@ -1026,13 +1037,13 @@ class EngineArgs:
" please file an issue with detailed information.")
scheduler_config = SchedulerConfig(
task=model_config.task,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
......
......@@ -344,7 +344,7 @@ class LLMEngine:
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
if self.model_config.task != "embedding":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
......@@ -1116,7 +1116,7 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
if self.model_config.task == "embedding":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
......@@ -1855,9 +1855,6 @@ class LLMEngine:
def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:
......
......@@ -8,7 +8,7 @@ from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
......@@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
......@@ -108,6 +108,12 @@ class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
......@@ -117,6 +123,13 @@ class LLM:
cls.DEPRECATE_LEGACY = False
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def __init__(
self,
model: str,
......@@ -139,6 +152,8 @@ class LLM:
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
**kwargs,
) -> None:
'''
......@@ -153,6 +168,7 @@ class LLM:
engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
......@@ -316,10 +332,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
task = self.llm_engine.model_config.task
if task != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).")
"models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
......@@ -692,10 +719,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
......@@ -905,6 +940,3 @@ class LLM:
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
......@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
self._enabled = self._check_embedding_mode(
model_config.task == "embedding")
async def create_embedding(
self,
......
......@@ -1034,10 +1034,54 @@ def identity(value: T) -> T:
F = TypeVar('F', bound=Callable[..., Any])
def deprecate_args(
start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
params = inspect.signature(fn).parameters
pos_types = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
pos_kws = [
kw for kw, param in params.items() if param.kind in pos_types
]
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_args = pos_kws[start_index:len(args)]
if deprecated_args:
msg = (
f"The positional arguments {deprecated_args} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
additional_message: Optional[str] = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
......
......@@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self._is_embedding_model():
elif model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
......@@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase):
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
def _is_embedding_model(self):
return self.model_config.is_embedding_model
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment