Unverified Commit a182be43 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[UX][Attention] Add `attention_config` argument to `LLM()` (#30710)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent c01d5898
......@@ -18,6 +18,7 @@ from vllm.beam_search import (
create_sort_beams_key_function,
)
from vllm.config import (
AttentionConfig,
CompilationConfig,
PoolerConfig,
ProfilerConfig,
......@@ -175,6 +176,10 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
attention_config: Configuration for attention mechanisms. Can be a
dictionary or an AttentionConfig instance. If a dictionary, it will
be converted to an AttentionConfig. Allows specifying the attention
backend and other attention-related settings.
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
Note:
......@@ -213,6 +218,7 @@ class LLM:
| StructuredOutputsConfig
| None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
......@@ -252,51 +258,28 @@ class LLM:
if hf_overrides is None:
hf_overrides = {}
if compilation_config is not None:
def _make_config(value: Any, cls: type[_R]) -> _R:
"""Convert dict/None/instance to a config instance."""
if value is None:
return cls()
if isinstance(value, dict):
return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type]
return value
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig(
**{
k: v
for k, v in compilation_config.items()
if is_init_field(CompilationConfig, k)
}
)
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = CompilationConfig()
if structured_outputs_config is not None:
if isinstance(structured_outputs_config, dict):
structured_outputs_instance = StructuredOutputsConfig(
**{
k: v
for k, v in structured_outputs_config.items()
if is_init_field(StructuredOutputsConfig, k)
}
compilation_config_instance = _make_config(
compilation_config, CompilationConfig
)
else:
structured_outputs_instance = structured_outputs_config
else:
structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
structured_outputs_instance = _make_config(
structured_outputs_config, StructuredOutputsConfig
)
else:
profiler_config_instance = profiler_config
else:
profiler_config_instance = ProfilerConfig()
profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
attention_config_instance = _make_config(attention_config, AttentionConfig)
# warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1))
......@@ -341,6 +324,7 @@ class LLM:
pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance,
attention_config=attention_config_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment