cache_control.rs 3.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use dynamo_runtime::{
    component::Component,
    pipeline::{PushRouter, RouterMode, SingleIn},
    protocols::annotated::Annotated,
};
use futures::StreamExt;

use crate::protocols::TokenIdType;

/// State captured at routing time for a deferred PIN after generation completes.
pub(crate) struct PinState {
    pub token_ids: Vec<TokenIdType>,
    pub cc_client: CacheControlClient,
    pub instance_id: u64,
    pub ttl_seconds: u64,
}

/// A PushRouter client typed for cache_control requests/responses.
///
/// Both request and response are untyped JSON. The worker's cache_control
/// endpoint returns {"status": "ok"/"error", ...} but the router treats
/// PIN as fire-and-forget and only logs the response at debug level.
pub type CacheControlClient = PushRouter<serde_json::Value, Annotated<serde_json::Value>>;

/// Create a cache_control client from a component.
///
/// Connects to the "cache_control" endpoint on the given component and returns
/// a PushRouter client for sending cache control operations (pin_prefix,
/// unpin_prefix) to workers.
pub(crate) async fn create_cache_control_client(
    component: &Component,
) -> Result<CacheControlClient> {
    let client = component.endpoint("cache_control").client().await?;
    CacheControlClient::from_client(client, RouterMode::KV).await
}

/// Fire-and-forget pin_prefix to the worker that served this request.
///
/// Spawns a detached task that sends the pin request and logs the outcome.
/// Does nothing if `client` is `None` (logs a warning).
pub fn spawn_pin_prefix(
    client: Option<&CacheControlClient>,
    token_ids: &[TokenIdType],
    instance_id: u64,
    context_id: &str,
    ttl_seconds: u64,
) {
    let Some(cc) = client else {
        tracing::warn!(
            request_id = %context_id,
            "cache_control set but no cache_control_client configured"
        );
        return;
    };

    let cc = cc.clone();
    let token_ids = token_ids.to_vec();
    let context_id = context_id.to_owned();

    tokio::spawn(async move {
        let pin_request = serde_json::json!({
            "action": "pin_prefix",
            "token_ids": token_ids,
            "ttl_seconds": ttl_seconds,
        });
        match cc.direct(SingleIn::new(pin_request), instance_id).await {
            Ok(mut stream) => {
                if let Some(resp) = stream.next().await {
                    tracing::info!(
                        request_id = %context_id,
                        worker_id = instance_id,
                        ?resp,
                        "pin_prefix response"
                    );
                }
                // Drain remaining stream to avoid "Failed to publish
                // complete final" errors from the push handler.
                while stream.next().await.is_some() {}
            }
            Err(e) => {
                tracing::warn!(
                    request_id = %context_id,
                    worker_id = instance_id,
                    "Failed to pin prefix: {e}"
                );
            }
        }
    });
}