"vscode:/vscode.git/clone" did not exist on "aa05dfd5f34c81ba6ef5e791fdad10e96688090b"
Unverified commit 4b2ed792, authored by Harry Mellor and committed by GitHub
Browse files

Improve configs - the rest! (#17562)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 7e357113
......@@ -9,7 +9,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test
......@@ -95,9 +95,6 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
......
......@@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from .backend import TestBackend
......@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
......
......@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
......@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
PassConfig(enable_fusion=True, enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
......
......@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
......@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True))
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
......
......@@ -6,7 +6,7 @@ import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from .backend import TestBackend
......@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True))
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)
......
......@@ -206,7 +206,7 @@ def _compare_sp(
'compile_sizes': [4, 8],
'splitting_ops': [],
'pass_config': {
'enable_sequence_parallism': sp_enabled,
'enable_sequence_parallelism': sp_enabled,
'enable_noop': True,
'enable_fusion': True,
},
......@@ -223,7 +223,7 @@ def _compare_sp(
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
str(compilation_config),
json.dumps(compilation_config),
]
tp_env = {
......
......@@ -8,21 +8,18 @@ from typing import Literal, Optional
import pytest
from vllm.config import config
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
optional_type)
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
......@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
def test_parse_type(type, value, expected):
parse_type_func = parse_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
assert parse_type_func(value) == expected
def test_optional_type():
    """Check an ``int`` parser wrapped by ``optional_type``.

    The wrapped parser must still convert a numeric string and must map
    the literal string ``"None"`` to ``None``.
    """
    parse_optional_int = optional_type(int)
    # The two checks are independent of each other.
    assert parse_optional_int("42") == 42
    assert parse_optional_int("None") is None
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
......@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class DummyConfigClass:
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class FromCliConfig1:
    # Test fixture: a config whose ``from_cli`` constructor adds 1 to
    # ``field`` so a test can tell which constructor produced the instance.
    field: int = 1
    """field"""

    @classmethod
    def from_cli(cls, cli_value: str):
        """Build an instance from a JSON object string, then add 1 to ``field``.

        Args:
            cli_value: JSON object whose keys are passed to the constructor.

        Returns:
            The new instance with ``field`` incremented by 1.
        """
        parsed = json.loads(cli_value)
        instance = cls(**parsed)
        instance.field = instance.field + 1
        return instance
@config
@dataclass
class FromCliConfig2:
    # Test fixture: same shape as the sibling from_cli fixture, but its
    # ``from_cli`` constructor adds 2 to ``field`` instead of 1.
    field: int = 1
    """field"""

    @classmethod
    def from_cli(cls, cli_value: str):
        """Build an instance from a JSON object string, then add 2 to ``field``.

        Args:
            cli_value: JSON object whose keys are passed to the constructor.

        Returns:
            The new instance with ``field`` incremented by 2.
        """
        parsed = json.loads(cli_value)
        instance = cls(**parsed)
        instance.field = instance.field + 2
        return instance
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
......@@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
"""Config with from_cli method"""
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
"""Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
(DummyConfig, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
kwargs = get_kwargs(DummyConfig)
print(kwargs)
# bools should not have their type set
......@@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
# nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
......@@ -177,7 +223,7 @@ def test_compilation_config():
# default value
args = parser.parse_args([])
assert args.compilation_config is None
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
......@@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
......@@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
......
......@@ -4,7 +4,7 @@ import time
import torch
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import PassConfig, VllmConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
......@@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass):
def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
def __init__(self, name: str, config: PassConfig, always=False):
super().__init__(config)
self.name = name
self.always = always
......
This diff is collapsed.
......@@ -5,6 +5,7 @@ import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import asdict
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
......@@ -284,7 +285,7 @@ class EventPublisherFactory:
if not config:
return NullEventPublisher()
config_dict = config.model_dump()
config_dict = asdict(config)
kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events")
......
......@@ -7,10 +7,10 @@ import json
import re
import threading
import warnings
from dataclasses import MISSING, dataclass, fields
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs, deprecated
......@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
# yapf: enable
......@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
......@@ -62,14 +60,24 @@ def optional_type(
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap a string parser so that null-like inputs become ``None``.

    Args:
        return_type: Callable that converts a string into a ``T``.

    Returns:
        A callable that returns ``None`` for ``""`` or ``"None"`` and
        otherwise the result of ``parse_type(return_type)`` applied to the
        input string.
    """

    def _optional_type(val: str) -> Optional[T]:
        is_null = val in ("", "None")
        return None if is_null else parse_type(return_type)(val)

    return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
    """Parse a value that may be either a plain string or a JSON object.

    Args:
        val: Raw string to interpret.

    Returns:
        ``val`` unchanged (as ``str``) when it does not look like a
        brace-delimited object; otherwise the result of
        ``optional_type(json.loads)`` applied to ``val``.
    """
    # Anything that is not wrapped in "{...}" is treated as a literal string.
    if not re.match("^{.*}$", val):
        return str(val)
    # Single fall-through return: the previous else-branch duplicated this
    # statement and left an unreachable copy after it.
    return optional_type(json.loads)(val)
@deprecated(
......@@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# If any of the field's possible types is a dataclass, remember it so the CLI value can be parsed from JSON into that dataclass
generator = (th for th in type_hints if is_dataclass(th))
dataclass_cls = next(generator, None)
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
if field.default is not MISSING:
default = field.default
elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
default = {}
else:
default = field.default_factory()
# Get the help text for the field
name = field.name
......@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool):
if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
......@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
......@@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--compilation-config",
"-O",
type=CompilationConfig.from_cli,
default=None,
help="torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.\n"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.\n"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{\"level\": 3, "
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``.")
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group = parser.add_argument_group(
title="KVTransferConfig",
description=KVTransferConfig.__doc__,
)
kv_transfer_group.add_argument(
"--kv-transfer-config",
type=KVTransferConfig.from_cli,
default=None,
help="The configurations for distributed KV cache "
"transfer. Should be a JSON string.")
kv_transfer_group.add_argument(
'--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
# vLLM arguments
# vllm_kwargs = get_kwargs(VllmConfig)
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
vllm_group.add_argument("--kv-transfer-config",
**vllm_kwargs["kv_transfer_config"])
vllm_group.add_argument('--kv-events-config',
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',
......
......@@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
......@@ -204,9 +205,13 @@ class LLM:
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
level=compilation_config)
elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items())))
else:
compilation_config_instance = compilation_config
else:
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
import torch
from tpu_info import device
......@@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
BlockSize = None
ModelConfig = None
VllmConfig = None
PoolingParams = None
......@@ -94,7 +95,7 @@ class TpuPlatform(Platform):
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_ONCE compilation level
......@@ -118,7 +119,7 @@ class TpuPlatform(Platform):
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config)
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
......@@ -128,7 +129,7 @@ class TpuPlatform(Platform):
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size
cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
......
......@@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
return isinstance(zmq, _MockModule)
except ModuleNotFoundError:
return False
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
"""
Import a Python file according to its file path.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment