// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::{Arc, Mutex};

pub use crate::kv_router::protocols::ForwardPassMetrics;

use crate::kv_router::scheduler::{Endpoint, Service};
use crate::kv_router::ProcessedEndpoints;
use dynamo_runtime::component::Component;
use std::time::Duration;
use tokio_util::sync::CancellationToken;

/// Holds the most recently collected endpoint set for a single service.
///
/// The stored value is refreshed by the background tasks spawned in
/// [`KvMetricsAggregator::new`] and read through
/// [`KvMetricsAggregator::get_endpoints`].
pub struct KvMetricsAggregator {
    /// Service name taken from `Component::service_name()` when the aggregator is built.
    pub service_name: String,
    /// Latest `ProcessedEndpoints` received from the collector task; each update
    /// replaces the previous value wholesale.
    pub endpoints: Arc<Mutex<ProcessedEndpoints>>,
}

impl KvMetricsAggregator {
    pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
        let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128);

        tokio::spawn(collect_endpoints(
            component.drt().nats_client().clone(),
            component.service_name(),
            ep_tx,
            cancellation_token.clone(),
        ));

        tracing::trace!("awaiting the start of the background endpoint subscriber");
        let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default()));
        let endpoints_clone = endpoints.clone();
        tokio::spawn(async move {
            tracing::debug!("scheduler background task started");
            loop {
                match ep_rx.recv().await {
                    Some(endpoints) => match endpoints_clone.lock() {
                        Ok(mut shared_endpoint) => {
                            *shared_endpoint = endpoints;
                        }
                        Err(e) => {
                            tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
                        }
                    },
                    None => {
                        tracing::warn!("endpoint subscriber shutdown");
                        break;
                    }
                };
            }

            tracing::trace!("background endpoint subscriber shutting down");
        });
        Self {
            service_name: component.service_name(),
            endpoints,
        }
    }

    pub fn get_endpoints(&self) -> ProcessedEndpoints {
        match self.endpoints.lock() {
            Ok(endpoints) => endpoints.clone(),
            Err(e) => {
                tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
                ProcessedEndpoints::default()
            }
        }
    }
}

83
pub async fn collect_endpoints(
Neelay Shah's avatar
Neelay Shah committed
84
    nats_client: dynamo_runtime::transports::nats::Client,
85
86
87
88
    service_name: String,
    ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
    cancel: CancellationToken,
) {
89
90
    let backoff_delay = Duration::from_millis(100);

91
92
93
94
95
96
    loop {
        tokio::select! {
            _ = cancel.cancelled() => {
                tracing::debug!("cancellation token triggered");
                break;
            }
97
            _ = tokio::time::sleep(backoff_delay) => {
98
                tracing::trace!("collecting endpoints for service: {}", service_name);
99
100
101
102
103
104
105
106
107
108
                let values = match nats_client
                    .get_endpoints(&service_name, Duration::from_millis(300))
                    .await
                {
                    Ok(v) => v,
                    Err(e) => {
                        tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
                        continue;
                    }
                };
109

110
111
112
113
114
115
116
117
118
119
120
121
122
                tracing::debug!("values: {:?}", values);
                let services: Vec<Service> = values
                    .into_iter()
                    .filter(|v| !v.is_empty())
                    .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
                        Ok(service) => Some(service),
                        Err(e) => {
                            tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
                            None
                        }
                    })
                    .collect();
                tracing::debug!("services: {:?}", services);
123

124
125
126
127
128
129
130
131
132
133
134
                let endpoints: Vec<Endpoint> = services
                    .into_iter()
                    .flat_map(|s| s.endpoints)
                    .filter(|s| s.data.is_some())
                    .map(|s| Endpoint {
                        name: s.name,
                        subject: s.subject,
                        data: s.data.unwrap(),
                    })
                    .collect();
                tracing::debug!("endpoints: {:?}", endpoints);
135

136
137
138
139
140
                tracing::trace!(
                    "found {} endpoints for service: {}",
                    endpoints.len(),
                    service_name
                );
141

142
143
144
145
146
147
                let processed = ProcessedEndpoints::new(endpoints);
                if ep_tx.send(processed).await.is_err() {
                    tracing::trace!("failed to send processed endpoints; shutting down");
                    break;
                }
            }
148
149
150
        }
    }
}