Unverified commit 76fd4716, authored by Alec, committed by GitHub
Browse files

refactor: support for turning prefix cache off (#2034)

parent ac7e8882
......@@ -73,6 +73,12 @@ def parse_args() -> Config:
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
if engine_args.enable_prefix_caching is None:
logger.debug(
"--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)"
)
engine_args.enable_prefix_caching = True
config = Config()
config.model = args.model
if args.served_model_name:
......@@ -214,18 +220,22 @@ def overwrite_args(config):
"task": "generate",
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
),
}
if config.engine_args.enable_prefix_caching:
# If caching, send events
defaults |= {
# Always setting up kv events if enable prefix cache.
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
),
)
}
set_side_channel_host_and_port(config)
......
......@@ -173,6 +173,11 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info(f"VllmWorker for {config.model} has been initialized")
handler = DecodeWorkerHandler(
component, engine_client, default_sampling_params, prefill_worker_client
)
if config.engine_args.enable_prefix_caching:
# TODO: We start off with a valid endpoint, then we increment it by dp_rank, so the
# result may no longer be valid. Let's remove the increment behavior from vLLM and here.
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
......@@ -189,11 +194,10 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info(f"Reading Events from {zmq_endpoint}")
handler = DecodeWorkerHandler(
component, engine_client, default_sampling_params, prefill_worker_client
)
handler.kv_publisher = kv_publisher
print(f"FINAL: {engine_client.vllm_config.cache_config.enable_prefix_caching}")
print(f"FINAL: {engine_client.vllm_config.kv_events_config}")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
......
......@@ -8,4 +8,4 @@ trap 'echo Cleaning up...; kill 0' EXIT
dynamo run in=http out=dyn &
# run worker
python3 components/main.py --model Qwen/Qwen3-0.6B --enforce-eager
python3 components/main.py --model Qwen/Qwen3-0.6B --enforce-eager --no-enable-prefix-caching
......@@ -197,6 +197,19 @@ vllm_configs = {
model="Qwen/Qwen3-0.6B",
delayed_start=45,
),
"agg-router": VLLMConfig(
name="agg-router",
directory="/workspace/examples/vllm",
script_name="agg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
),
"disaggregated": VLLMConfig(
name="disaggregated",
directory="/workspace/examples/vllm",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment