Commit 2cf67765 authored by Richard Huo, committed by GitHub

feat: DIS-323 [trtllm backend publisher] only publish kv event with the biggest window size to support kv routing with variable sliding window attention (#2241)
parent 5fad47f7
...@@ -24,4 +24,3 @@ kv_cache_config: ...@@ -24,4 +24,3 @@ kv_cache_config:
- 512 - 512
- 512 - 512
- 32768 - 32768
enable_block_reuse: false
...@@ -24,7 +24,6 @@ kv_cache_config: ...@@ -24,7 +24,6 @@ kv_cache_config:
- 512 - 512
- 512 - 512
- 32768 - 32768
enable_block_reuse: false
cache_transceiver_config: cache_transceiver_config:
backend: default backend: default
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
tensor_parallel_size: 1 tensor_parallel_size: 1
backend: pytorch backend: pytorch
disable_overlap_scheduler: True disable_overlap_scheduler: true
kv_cache_config: kv_cache_config:
max_attention_window: max_attention_window:
...@@ -25,7 +25,6 @@ kv_cache_config: ...@@ -25,7 +25,6 @@ kv_cache_config:
- 512 - 512
- 512 - 512
- 32768 - 32768
enable_block_reuse: false
cache_transceiver_config: cache_transceiver_config:
backend: default backend: default
...@@ -21,10 +21,11 @@ This guide demonstrates how to deploy google/gemma-3-1b-it with Variable Sliding ...@@ -21,10 +21,11 @@ This guide demonstrates how to deploy google/gemma-3-1b-it with Variable Sliding
VSWA is a mechanism in which a model’s layers alternate between multiple sliding window sizes. An example of this is Gemma 3, which incorporates both global attention layers and sliding window layers. VSWA is a mechanism in which a model’s layers alternate between multiple sliding window sizes. An example of this is Gemma 3, which incorporates both global attention layers and sliding window layers.
## Notes ## Notes
-* To run Gemma 3 with VSWA, ensure that the container has TensorRT-LLM v1.0.0rc4 installed.
-
-## Limitation
-* The current KV event-based KV routing does not work well with VSWA. The Dynamo team is actively working on adding support to distinguish between events from different layer groups.
+* To run Gemma 3 with VSWA and KV routing with KV block reuse, ensure that the container is built from TensorRT-LLM commit `c9eebcb4541d961ab390f0bd0a22e2c89f1bcc78`.
+```bash
+./container/build.sh --framework TENSORRTLLM --tensorrtllm-commit c9eebcb4541d961ab390f0bd0a22e2c89f1bcc78
+```
+* The 1.0.0rc4 release of TensorRT-LLM can also run Gemma 3 with VSWA, but KV block reuse cannot be enabled in that version.
### Aggregated Serving ### Aggregated Serving
```bash ```bash
...@@ -35,6 +36,15 @@ export AGG_ENGINE_ARGS=engine_configs/gemma3/vswa_agg.yaml ...@@ -35,6 +36,15 @@ export AGG_ENGINE_ARGS=engine_configs/gemma3/vswa_agg.yaml
./launch/agg.sh ./launch/agg.sh
``` ```
### Aggregated Serving with KV Routing
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export MODEL_PATH=google/gemma-3-1b-it
export SERVED_MODEL_NAME=$MODEL_PATH
export AGG_ENGINE_ARGS=engine_configs/gemma3/vswa_agg.yaml
./launch/agg_router.sh
```
#### Disaggregated Serving #### Disaggregated Serving
```bash ```bash
cd $DYNAMO_HOME/components/backends/trtllm cd $DYNAMO_HOME/components/backends/trtllm
...@@ -44,3 +54,13 @@ export PREFILL_ENGINE_ARGS=engine_configs/gemma3/vswa_prefill.yaml ...@@ -44,3 +54,13 @@ export PREFILL_ENGINE_ARGS=engine_configs/gemma3/vswa_prefill.yaml
export DECODE_ENGINE_ARGS=engine_configs/gemma3/vswa_decode.yaml export DECODE_ENGINE_ARGS=engine_configs/gemma3/vswa_decode.yaml
./launch/disagg.sh ./launch/disagg.sh
``` ```
#### Disaggregated Serving with KV Routing
```bash
cd $DYNAMO_HOME/components/backends/trtllm
export MODEL_PATH=google/gemma-3-1b-it
export SERVED_MODEL_NAME=$MODEL_PATH
export PREFILL_ENGINE_ARGS=engine_configs/gemma3/vswa_prefill.yaml
export DECODE_ENGINE_ARGS=engine_configs/gemma3/vswa_decode.yaml
./launch/disagg_router.sh
```
...@@ -117,6 +117,12 @@ class Publisher: ...@@ -117,6 +117,12 @@ class Publisher:
self.kv_listener = kv_listener self.kv_listener = kv_listener
self.worker_id = worker_id self.worker_id = worker_id
self.kv_block_size = kv_block_size self.kv_block_size = kv_block_size
self.max_window_size = None
# The first few kv events from the model engine are always "created" type events.
# Use these events to capture the max_window_size of the model.
# When the first event that is not a "created" type is received, the publisher will set this to False to stop processing "created" type events.
self.processing_initial_created_events = True
# Needed by the events and metrics publishers # Needed by the events and metrics publishers
self.metrics_publisher = None self.metrics_publisher = None
...@@ -289,9 +295,14 @@ class Publisher: ...@@ -289,9 +295,14 @@ class Publisher:
events = self.engine.llm.get_kv_cache_events_async(timeout=5) events = self.engine.llm.get_kv_cache_events_async(timeout=5)
async for event in events: async for event in events:
logging.debug(f"KV cache event received: {event}") logging.debug(f"KV cache event received: {event}")
# Drop events that are not emitted from the global attention layer.
if self.should_drop_event(event):
continue
event_id = event["event_id"] event_id = event["event_id"]
data = event["data"] data = event["data"]
if data["type"] == "stored": if data["type"] == "stored":
self.processing_initial_created_events = False
parent_hash = _to_signed_i64(data["parent_hash"]) parent_hash = _to_signed_i64(data["parent_hash"])
token_ids = [] token_ids = []
num_block_tokens = [] num_block_tokens = []
...@@ -332,6 +343,7 @@ class Publisher: ...@@ -332,6 +343,7 @@ class Publisher:
parent_hash, parent_hash,
) )
elif data["type"] == "removed": elif data["type"] == "removed":
self.processing_initial_created_events = False
block_hashes = [] block_hashes = []
for block_hash in data["block_hashes"]: for block_hash in data["block_hashes"]:
block_hash = _to_signed_i64(block_hash) block_hash = _to_signed_i64(block_hash)
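The hunks above call `_to_signed_i64` on parent and block hashes, but its definition is not part of this diff. A minimal sketch of what such a helper might look like, assuming it reinterprets an unsigned 64-bit hash as a signed two's-complement value (the name `to_signed_i64` and the range check are hypothetical, not taken from the repository):

```python
def to_signed_i64(value: int) -> int:
    """Reinterpret an unsigned 64-bit integer as a signed 64-bit integer.

    Hypothetical stand-in for the `_to_signed_i64` helper referenced in the
    diff; the real implementation is not shown there and may differ.
    """
    if not 0 <= value < 2**64:
        raise ValueError("value does not fit in an unsigned 64-bit integer")
    # Values with the top bit set map to negative numbers (two's complement).
    return value - 2**64 if value >= 2**63 else value


small_hash = to_signed_i64(5)          # below 2**63, returned unchanged
wrapped_hash = to_signed_i64(2**64 - 1)  # top bit set, wraps to a negative value
```

A conversion like this would matter if the consumer of the published events stores hashes as signed 64-bit integers, which is an assumption here rather than something the diff states.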
...@@ -347,6 +359,9 @@ class Publisher: ...@@ -347,6 +359,9 @@ class Publisher:
f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}" f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
) )
self.kv_event_publisher.publish_removed(event_id, block_hashes) self.kv_event_publisher.publish_removed(event_id, block_hashes)
elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event)
return True return True
def start(self): def start(self):
...@@ -394,6 +409,42 @@ class Publisher: ...@@ -394,6 +409,42 @@ class Publisher:
if self.publish_kv_cache_events_thread.is_alive(): if self.publish_kv_cache_events_thread.is_alive():
logging.warning("KV cache events thread did not stop within timeout") logging.warning("KV cache events thread did not stop within timeout")
def update_max_window_size(self, event):
if "window_size" in event:
window_size = event["window_size"]
if self.max_window_size is None or window_size > self.max_window_size:
self.max_window_size = window_size
logging.debug(
f"kv events max_window_size has been updated to {self.max_window_size}"
)
# The global attention layer will emit the KV event with the max_window_size.
# We only want to keep the KV event that has the max_window_size to ensure
# the accuracy of KV routing.
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# so we can use the "created" event to identify the max_window_size of the global
# attention layer in the model engine.
def should_drop_event(self, event):
# There are two cases for KV event filtering:
#
# 1. If "window_size" is NOT in the KV event:
# "window_size" was added to KV events only recently, so some older versions of TRTLLM
# might not include it. In this case, the publisher will assume that all events are
# from the global attention layer.
#
# 2. If "window_size" is present in the KV event:
# The publisher will not drop any KV events until all initial "created" KV events
# have been processed in order to capture the max_window_size.
# After processing all "created" events, the publisher will only accept KV events
# whose window_size is equal to the max_window_size to ensure accurate routing.
if "window_size" not in event or self.processing_initial_created_events:
return False
if event["window_size"] != self.max_window_size:
return True
return False
@asynccontextmanager @asynccontextmanager
async def get_publisher(component, engine, kv_listener, worker_id, kv_block_size): async def get_publisher(component, engine, kv_listener, worker_id, kv_block_size):
......
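The filtering added to `Publisher` in this commit can be exercised in isolation. The sketch below mirrors `update_max_window_size` and `should_drop_event` in a hypothetical `WindowFilter` class; the `accept` method and the sample events are illustrative assumptions, and only the `window_size`, `event_id` and `data["type"]` keys come from the diff.

```python
from typing import Optional


class WindowFilter:
    """Standalone sketch of the publisher's window-size filter (hypothetical class)."""

    def __init__(self) -> None:
        self.max_window_size: Optional[int] = None
        self.processing_initial_created_events = True

    def update_max_window_size(self, event: dict) -> None:
        # Track the largest window size seen in the initial "created" events.
        window_size = event.get("window_size")
        if window_size is not None and (
            self.max_window_size is None or window_size > self.max_window_size
        ):
            self.max_window_size = window_size

    def should_drop_event(self, event: dict) -> bool:
        # Events without "window_size" (older TRTLLM) are never dropped, and
        # nothing is dropped while the initial "created" events are processed.
        if "window_size" not in event or self.processing_initial_created_events:
            return False
        return event["window_size"] != self.max_window_size

    def accept(self, event: dict) -> bool:
        """Return True if the event would be published (illustrative helper)."""
        if self.should_drop_event(event):
            return False
        event_type = event["data"]["type"]
        if event_type in ("stored", "removed"):
            self.processing_initial_created_events = False
            return True
        if event_type == "created" and self.processing_initial_created_events:
            self.update_max_window_size(event)
        return False


# Made-up events: two window sizes, plus one event without "window_size".
events = [
    {"event_id": 0, "window_size": 512, "data": {"type": "created"}},
    {"event_id": 1, "window_size": 32768, "data": {"type": "created"}},
    {"event_id": 2, "window_size": 512, "data": {"type": "stored"}},
    {"event_id": 3, "window_size": 32768, "data": {"type": "stored"}},
    {"event_id": 4, "window_size": 512, "data": {"type": "removed"}},
    {"event_id": 5, "data": {"type": "stored"}},
]

window_filter = WindowFilter()
published = [e["event_id"] for e in events if window_filter.accept(e)]
print(published)  # [2, 3, 5]
```

Event 2 is published even though its window size is 512: as in the diff, the flag is cleared only after the drop check, so the first `stored` or `removed` event passes regardless of its window size. Whether that is intended is not stated in the commit.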