Unverified Commit 6865fe00 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix interaction between `Optional` and `Annotated` in CLI typing (#19093)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarYikun Jiang <yikun@apache.org>
parent e31446b6
......@@ -5,14 +5,14 @@ import json
from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
import pytest
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
get_type, get_type_hints, is_not_builtin,
is_type, literal_to_kwargs, nullable_kvs,
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
......@@ -160,6 +160,18 @@ def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
@pytest.mark.parametrize(
("type_hint", "expected"), [
(Annotated[int, "annotation"], {int}),
(Optional[int], {int, type(None)}),
(Annotated[Optional[int], "annotation"], {int, type(None)}),
(Optional[Annotated[int, "annotation"]], {int, type(None)}),
],
ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"])
def test_get_type_hints(type_hint, expected):
assert get_type_hints(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfig)
print(kwargs)
......
......@@ -15,7 +15,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
import regex as re
import torch
from pydantic import SkipValidation, TypeAdapter, ValidationError
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
......@@ -151,17 +151,29 @@ def is_not_builtin(type_hint: TypeHint) -> bool:
return type_hint.__module__ != "builtins"
def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
"""Extract type hints from Annotated or Union type hints."""
type_hints: set[TypeHint] = set()
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is Annotated:
type_hints.update(get_type_hints(args[0]))
elif origin is Union:
for arg in args:
type_hints.update(get_type_hints(arg))
else:
type_hints.add(type_hint)
return type_hints
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
predicate = lambda arg: not isinstance(arg, SkipValidation)
type_hints.update(filter(predicate, get_args(field.type)))
else:
type_hints.add(field.type)
type_hints: set[TypeHint] = get_type_hints(field.type)
# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment