Unverified commit 76fd4716, authored by Alec, committed by GitHub
Browse files

refactor: support for turning prefix cache off (#2034)

parent ac7e8882
...@@ -73,6 +73,12 @@ def parse_args() -> Config: ...@@ -73,6 +73,12 @@ def parse_args() -> Config:
args = parser.parse_args() args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
if engine_args.enable_prefix_caching is None:
logger.debug(
"--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)"
)
engine_args.enable_prefix_caching = True
config = Config() config = Config()
config.model = args.model config.model = args.model
if args.served_model_name: if args.served_model_name:
...@@ -214,20 +220,24 @@ def overwrite_args(config): ...@@ -214,20 +220,24 @@ def overwrite_args(config):
"task": "generate", "task": "generate",
"skip_tokenizer_init": True, "skip_tokenizer_init": True,
"disable_log_requests": True, "disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics # KV routing relies on logging KV metrics
"disable_log_stats": False, "disable_log_stats": False,
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig( "kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both" kv_connector="NixlConnector", kv_role="kv_both"
), ),
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
),
} }
if config.engine_args.enable_prefix_caching:
# If caching, send events
defaults |= {
# Always setting up kv events if enable prefix cache.
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
)
}
set_side_channel_host_and_port(config) set_side_channel_host_and_port(config)
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
......
...@@ -173,27 +173,31 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -173,27 +173,31 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info(f"VllmWorker for {config.model} has been initialized") logger.info(f"VllmWorker for {config.model} has been initialized")
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
component, engine_client, default_sampling_params, prefill_worker_client component, engine_client, default_sampling_params, prefill_worker_client
) )
handler.kv_publisher = kv_publisher
if config.engine_args.enable_prefix_caching:
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
handler.kv_publisher = kv_publisher
print(f"FINAL: {engine_client.vllm_config.cache_config.enable_prefix_caching}")
print(f"FINAL: {engine_client.vllm_config.kv_events_config}")
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), generate_endpoint.serve_endpoint(handler.generate),
......
...@@ -8,4 +8,4 @@ trap 'echo Cleaning up...; kill 0' EXIT ...@@ -8,4 +8,4 @@ trap 'echo Cleaning up...; kill 0' EXIT
dynamo run in=http out=dyn & dynamo run in=http out=dyn &
# run worker # run worker
python3 components/main.py --model Qwen/Qwen3-0.6B --enforce-eager python3 components/main.py --model Qwen/Qwen3-0.6B --enforce-eager --no-enable-prefix-caching
...@@ -197,6 +197,19 @@ vllm_configs = { ...@@ -197,6 +197,19 @@ vllm_configs = {
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
delayed_start=45, delayed_start=45,
), ),
"agg-router": VLLMConfig(
name="agg-router",
directory="/workspace/examples/vllm",
script_name="agg_router.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions", "v1/completions"],
response_handlers=[
chat_completions_response_handler,
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
),
"disaggregated": VLLMConfig( "disaggregated": VLLMConfig(
name="disaggregated", name="disaggregated",
directory="/workspace/examples/vllm", directory="/workspace/examples/vllm",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment