Unverified Commit e3793961 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Clean up processor kwargs extraction (#35872)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6e9f21e8
...@@ -7,7 +7,8 @@ from transformers.processing_utils import ProcessingKwargs ...@@ -7,7 +7,8 @@ from transformers.processing_utils import ProcessingKwargs
from typing_extensions import Unpack from typing_extensions import Unpack
from vllm.transformers_utils.processor import ( from vllm.transformers_utils.processor import (
get_processor_kwargs_from_processor, get_processor_kwargs_keys,
get_processor_kwargs_type,
) )
...@@ -35,7 +36,7 @@ def _assert_has_all_expected(keys: set[str]) -> None: ...@@ -35,7 +36,7 @@ def _assert_has_all_expected(keys: set[str]) -> None:
assert k in keys assert k in keys
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs] # Path 1: __call__ method has kwargs: Unpack[*ProcessorKwargs]
class _ProcWithUnpack: class _ProcWithUnpack:
def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore
return None return None
...@@ -43,11 +44,11 @@ class _ProcWithUnpack: ...@@ -43,11 +44,11 @@ class _ProcWithUnpack:
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union(): def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
proc = _ProcWithUnpack() proc = _ProcWithUnpack()
keys = get_processor_kwargs_from_processor(proc) keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
_assert_has_all_expected(keys) _assert_has_all_expected(keys)
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ---- # ---- Path 2: No Unpack, fallback to scanning *ProcessorKwargs in module ----
class _ProcWithoutUnpack: class _ProcWithoutUnpack:
...@@ -62,5 +63,5 @@ def test_get_processor_kwargs_from_processor_module_scan_returns_full_union(): ...@@ -62,5 +63,5 @@ def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
assert hasattr(mod, "_FakeProcessorKwargs") assert hasattr(mod, "_FakeProcessorKwargs")
proc = _ProcWithoutUnpack() proc = _ProcWithoutUnpack()
keys = get_processor_kwargs_from_processor(proc) keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
_assert_has_all_expected(keys) _assert_has_all_expected(keys)
...@@ -111,29 +111,6 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]): ...@@ -111,29 +111,6 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
return processor_cls return processor_cls
@lru_cache
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
dynamic_kwargs: set[str] = set()
if kwargs_cls is None:
return dynamic_kwargs
# get kwargs annotations in processor
# merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs
kwargs_type_annotations = get_type_hints(kwargs_cls)
for kw_type in ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"):
if kw_type in kwargs_type_annotations:
# Use __annotations__ instead of get_type_hints() to avoid
# NameError from unresolved forward references (e.g.
# PILImageResampling). We only need key names, not types.
kw_cls = kwargs_type_annotations[kw_type]
kw_annotations: dict[str, Any] = {}
for base in reversed(kw_cls.__mro__):
kw_annotations.update(getattr(base, "__annotations__", {}))
for kw_name in kw_annotations:
dynamic_kwargs.add(kw_name)
dynamic_kwargs |= {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
return dynamic_kwargs
def _merge_mm_kwargs( def _merge_mm_kwargs(
model_config: "ModelConfig", model_config: "ModelConfig",
processor_cls: type | tuple[type, ...], processor_cls: type | tuple[type, ...],
...@@ -224,38 +201,63 @@ cached_get_processor = lru_cache(get_processor) ...@@ -224,38 +201,63 @@ cached_get_processor = lru_cache(get_processor)
@lru_cache @lru_cache
def get_processor_kwargs_from_processor(processor: _P) -> set[str]: def get_processor_kwargs_type(
processor: ProcessorMixin,
) -> type[processing_utils.ProcessingKwargs]:
try: try:
# get kwargs annotations in processor # get kwargs annotations in processor
call_kwargs = inspect.signature(type(processor).__call__).parameters.get( call_params = inspect.signature(type(processor).__call__).parameters
"kwargs" call_kwargs = call_params.get("kwargs")
)
call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None
# if the processor has explicit kwargs annotation, use it # if the processor has explicit kwargs annotation, use it
if call_kwargs_annotations not in (None, inspect._empty): if call_kwargs_annotations not in (None, inspect._empty):
# get_type_hints will parse all type annotations at runtime, # get_type_hints will parse all type annotations at runtime,
# and if an annotation refers to a type or # and if an annotation refers to a type or
# name that hasn’t been imported or defined, it will raise an error. # name that hasn’t been imported or defined, it will raise an error.
# So we use __annotations__ to get the raw annotations directly. # So we use __annotations__ to get the raw annotations directly.
return _collect_dynamic_keys_from_processing_kwargs( return get_args(call_kwargs_annotations)[0]
get_args(call_kwargs_annotations)[0]
) # otherwise, try to get from ProcessorKwargs
# otherwise, try to get from ProcessingKwargs
else:
module_name = type(processor).__module__ module_name = type(processor).__module__
mod = importlib.import_module(module_name) mod = importlib.import_module(module_name)
# find *ProcessingKwargs in the module
processor_kwargs: set[str] = set()
for name, obj in vars(mod).items(): for name, obj in vars(mod).items():
if name.endswith("ProcessingKwargs"): if name.endswith("ProcessorKwargs"):
processor_kwargs = ( return obj
processor_kwargs
| _collect_dynamic_keys_from_processing_kwargs(obj)
)
return processor_kwargs
except Exception: except Exception:
logger.exception("Failed to collect processor kwargs") logger.exception("Failed to collect processor kwargs")
return set()
return processing_utils.ProcessingKwargs
@lru_cache
def get_processor_kwargs_keys(
kwargs_cls: type[processing_utils.ProcessingKwargs],
) -> set[str]:
dynamic_kwargs: set[str] = set()
modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
try:
# get kwargs annotations in processor
# merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs
kwargs_type_annotations = get_type_hints(kwargs_cls)
for kw_type in modality_kwargs:
if kw_type in kwargs_type_annotations:
# Use __annotations__ instead of get_type_hints() to avoid
# NameError from unresolved forward references (e.g.
# PILImageResampling). We only need key names, not types.
kw_cls = kwargs_type_annotations[kw_type]
kw_annotations: dict[str, Any] = {}
for base in reversed(kw_cls.__mro__):
kw_annotations.update(getattr(base, "__annotations__", {}))
for kw_name in kw_annotations:
dynamic_kwargs.add(kw_name)
except Exception:
logger.exception("Failed to collect processor kwargs")
return dynamic_kwargs | modality_kwargs
def cached_get_processor_without_dynamic_kwargs( def cached_get_processor_without_dynamic_kwargs(
...@@ -275,7 +277,9 @@ def cached_get_processor_without_dynamic_kwargs( ...@@ -275,7 +277,9 @@ def cached_get_processor_without_dynamic_kwargs(
) )
# Step 2: use temporary processor collect dynamic keys # Step 2: use temporary processor collect dynamic keys
dynamic_keys = get_processor_kwargs_from_processor(processor) dynamic_keys = get_processor_kwargs_keys(
get_processor_kwargs_type(processor) # type: ignore[arg-type]
)
# Step 3: use dynamic_keys filter kwargs # Step 3: use dynamic_keys filter kwargs
filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys} filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment