sequence.rs 14.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
//! Runtime-specific glue for [`ActiveSequencesMultiWorker`].
5
//!
6
7
8
9
10
11
12
13
//! This module provides the concrete [`SequencePublisher`] and [`SequenceSubscriber`]
//! implementations that wire the runtime-agnostic business logic (in `dynamo_kv_router`)
//! to NATS event transport and Prometheus metrics.

pub use dynamo_kv_router::multi_worker_sequence::{
    ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
    SequenceSubscriber,
};
14
use dynamo_kv_router::protocols::{ActiveLoad, ActiveSequenceEvent, WorkerWithDpRank};
15
pub use dynamo_kv_router::sequence::{ActiveSequences, RequestId};
16

17
18
19
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
20
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
21
use std::collections::HashMap;
22
23
use std::sync::Arc;

24
use super::metrics::WORKER_LOAD_METRICS;
25
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
Yan Ru Pei's avatar
Yan Ru Pei committed
26
use crate::local_model::runtime_config::ModelRuntimeConfig;
27
28
#[cfg(test)]
use dynamo_kv_router::protocols::PrefillLoadHint;
29

30
31
/// Concrete [`SequencePublisher`] backed by NATS [`EventPublisher`] and Prometheus gauges.
pub struct RuntimeSequencePublisher {
32
    event_publisher: EventPublisher,
33
    metrics_publisher: Arc<EventPublisher>,
34
35
}

36
37
38
impl SequencePublisher for RuntimeSequencePublisher {
    async fn publish_event(&self, event: &ActiveSequenceEvent) -> anyhow::Result<()> {
        self.event_publisher.publish(event).await
39
40
    }

41
    fn publish_load(&self, load: ActiveLoad) {
42
43
        let publisher = self.metrics_publisher.clone();
        tokio::spawn(async move {
44
            if let Err(e) = publisher.publish(&load).await {
45
                tracing::trace!(
46
47
48
                    "Failed to publish ActiveLoad to NATS for worker (id={}, dp_rank={}): {e:?}",
                    load.worker_id,
                    load.dp_rank
49
50
51
                );
            }
        });
52
53
    }

54
    fn observe_load(
55
        &self,
56
57
58
59
60
61
62
63
64
65
66
67
        worker: &WorkerWithDpRank,
        worker_type: &str,
        blocks: usize,
        tokens: usize,
    ) {
        WORKER_LOAD_METRICS.observe(
            worker.worker_id,
            worker.dp_rank,
            worker_type,
            blocks,
            tokens,
        );
68
    }
69
}
70

71
72
73
74
/// Concrete [`SequenceSubscriber`] backed by NATS typed event stream.
///
/// Thin wrapper around a typed event subscriber for [`ActiveSequenceEvent`];
/// the [`SequenceSubscriber`] implementation below pulls events from it.
pub struct RuntimeSequenceSubscriber {
    /// Underlying typed subscriber that `next_event` reads from.
    inner: dynamo_runtime::transports::event_plane::TypedEventSubscriber<ActiveSequenceEvent>,
}
75

76
77
78
79
80
impl SequenceSubscriber for RuntimeSequenceSubscriber {
    /// Waits for the next item from the wrapped subscriber.
    ///
    /// Returns `None` when the wrapped subscriber yields `None`. Otherwise the
    /// envelope is discarded and only the event (or the error) is passed on.
    async fn next_event(&mut self) -> Option<anyhow::Result<ActiveSequenceEvent>> {
        let received = self.inner.next().await?;
        Some(received.map(|(_envelope, event)| event))
    }
}
84

85
86
/// Type alias for the runtime-wired multi-worker sequence tracker.
///
/// This is [`ActiveSequencesMultiWorker`] parameterized with
/// [`RuntimeSequencePublisher`]; [`create_multi_worker_sequences`] returns it
/// wrapped in an `Arc`.
pub type ActiveSequencesMulti = ActiveSequencesMultiWorker<RuntimeSequencePublisher>;
87

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/// Convenience async constructor that creates the NATS publishers/subscribers
/// and returns an `Arc<ActiveSequencesMulti>` with replica sync already running.
pub async fn create_multi_worker_sequences(
    component: Component,
    block_size: usize,
    workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
    replica_sync: bool,
    router_id: u64,
    worker_type: &'static str,
) -> Result<Arc<ActiveSequencesMulti>> {
    let event_publisher =
        EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
    let metrics_publisher =
        Arc::new(EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?);

    let publisher = RuntimeSequencePublisher {
        event_publisher,
        metrics_publisher,
    };

108
    let dp_range: HashMap<u64, (u32, u32)> = workers_with_configs
109
        .into_iter()
110
111
112
113
114
115
        .map(|(id, config)| {
            (
                id,
                (config.data_parallel_start_rank, config.data_parallel_size),
            )
        })
116
117
118
119
120
        .collect();

    let multi_worker = ActiveSequencesMultiWorker::new(
        publisher,
        block_size,
121
        dp_range,
122
123
124
125
126
127
128
129
130
131
132
133
134
135
        replica_sync,
        router_id,
        worker_type,
    );

    let arc = Arc::new(multi_worker);

    if replica_sync {
        let subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
            .await?
            .typed::<ActiveSequenceEvent>();
        let subscriber = RuntimeSequenceSubscriber { inner: subscriber };
        let cancel_token = component.drt().runtime().child_token();
        arc.start_replica_sync(subscriber, cancel_token);
136
    }
137

138
139
140
    let expiry_cancel = component.drt().runtime().child_token();
    arc.start_periodic_force_expiry_across_all_workers(expiry_cancel);

141
    Ok(arc)
142
143
144
145
146
}

#[cfg(test)]
mod tests {
    use super::*;
147
    use dynamo_runtime::{DistributedRuntime, Runtime};
148
    use tokio::time::Instant;
149

150
151
152
153
154
155
156
    /// Builds a `Some(PrefillLoadHint)` whose initial effective prefill token
    /// count is `tokens` and whose expected prefill duration is unset.
    fn tracking_hint(tokens: usize) -> Option<PrefillLoadHint> {
        let hint = PrefillLoadHint {
            initial_effective_prefill_tokens: tokens,
            expected_prefill_duration: None,
        };
        Some(hint)
    }

157
    #[tokio::test]
158
    #[ignore]
159
    async fn test_multi_worker_cross_instance_sync() -> Result<()> {
160
161
        dynamo_runtime::logging::init();

162
        let block_size = 4;
163

164
165
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
166

167
        let namespace = distributed.namespace("test_cross_instance_sync")?;
168
        let component = namespace.component("sequences")?;
169

Yan Ru Pei's avatar
Yan Ru Pei committed
170
171
172
173
        let mut workers_with_configs = HashMap::new();

        let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
        config_worker_0.data_parallel_size = 2;
174
        workers_with_configs.insert(0, config_worker_0);
Yan Ru Pei's avatar
Yan Ru Pei committed
175
176

        let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
177
        workers_with_configs.insert(1, config_worker_1);
Yan Ru Pei's avatar
Yan Ru Pei committed
178

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
        let seq_manager_1 = create_multi_worker_sequences(
            component.clone(),
            block_size,
            workers_with_configs.clone(),
            true,
            1,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
        let seq_manager_2 = create_multi_worker_sequences(
            component,
            block_size,
            workers_with_configs,
            true,
            2,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
197
198

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
199
        let decay_now = Instant::now();
200

201
202
203
204
205
206
        seq_manager_1.add_request(
            SequenceRequest {
                request_id: "request_0".to_string(),
                token_sequence: Some(vec![0, 1, 2]),
                track_prefill_tokens: true,
                expected_output_tokens: None,
207
                prefill_load_hint: tracking_hint(12),
208
209
210
                worker: WorkerWithDpRank::new(0, 0),
                lora_name: None,
            },
211
            decay_now,
212
213
214
215
216
217
218
219
        )?;

        seq_manager_1.add_request(
            SequenceRequest {
                request_id: "request_1".to_string(),
                token_sequence: Some(vec![3, 4]),
                track_prefill_tokens: true,
                expected_output_tokens: None,
220
                prefill_load_hint: tracking_hint(8),
221
222
223
                worker: WorkerWithDpRank::new(0, 1),
                lora_name: None,
            },
224
            decay_now,
225
226
227
228
229
230
231
232
        )?;

        seq_manager_2.add_request(
            SequenceRequest {
                request_id: "request_2".to_string(),
                token_sequence: Some(vec![0, 1, 2, 3]),
                track_prefill_tokens: true,
                expected_output_tokens: None,
233
                prefill_load_hint: tracking_hint(16),
234
235
236
                worker: WorkerWithDpRank::new(1, 0),
                lora_name: None,
            },
237
            decay_now,
238
        )?;
239

240
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
241

242
        let blocks_phase1 = seq_manager_1.active_blocks();
243
        let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
244

Yan Ru Pei's avatar
Yan Ru Pei committed
245
246
247
248
        let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
        let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
        let worker_1_dp0 = WorkerWithDpRank::new(1, 0);

249
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
250
251
            blocks_phase1[&worker_0_dp0], 3,
            "Worker 0 dp_rank 0 should have 3 active blocks (from request_0)"
252
        );
253
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
254
255
            blocks_phase1[&worker_0_dp1], 2,
            "Worker 0 dp_rank 1 should have 2 active blocks (from request_1)"
256
257
        );
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
258
259
            blocks_phase1[&worker_1_dp0], 4,
            "Worker 1 dp_rank 0 should have 4 active blocks (from request_2 added by seq_manager_2)"
260
261
        );
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
262
263
            tokens_phase1[&worker_0_dp0], 12,
            "Worker 0 dp_rank 0 should have 12 active tokens"
264
        );
265
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
266
267
268
269
270
271
            tokens_phase1[&worker_0_dp1], 8,
            "Worker 0 dp_rank 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_1_dp0], 16,
            "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
272
        );
273

274
        seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
275

276
277
        seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
        seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
278
279
280

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

281
        let blocks_phase2 = seq_manager_2.active_blocks();
282
        let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
283

Yan Ru Pei's avatar
Yan Ru Pei committed
284
285
286
287
288
289
290
        let all_workers = vec![
            WorkerWithDpRank::new(0, 0),
            WorkerWithDpRank::new(0, 1),
            WorkerWithDpRank::new(1, 0),
        ];

        for worker in all_workers {
291
            assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
292
293
294
                blocks_phase2[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active blocks after all requests freed",
                worker.worker_id, worker.dp_rank
295
296
            );
            assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
297
298
299
                tokens_phase2[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active tokens after all requests freed",
                worker.worker_id, worker.dp_rank
300
301
302
303
304
305
306
307
308
309
310
            );
        }

        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
        dynamo_runtime::logging::init();

311
        let block_size = 4;
312
313
314
315
316

        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

        let namespace = distributed.namespace("test_no_token_seq_sync")?;
317
        let component = namespace.component("sequences")?;
318

Yan Ru Pei's avatar
Yan Ru Pei committed
319
        let mut workers_with_configs = HashMap::new();
320
321
322
323
324
325
326
327
328
329
330
331
        workers_with_configs.insert(
            0,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );
        workers_with_configs.insert(
            1,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );
        workers_with_configs.insert(
            2,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );
Yan Ru Pei's avatar
Yan Ru Pei committed
332

333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
        let seq_manager_1 = create_multi_worker_sequences(
            component.clone(),
            block_size,
            workers_with_configs.clone(),
            true,
            1,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
        let seq_manager_2 = create_multi_worker_sequences(
            component,
            block_size,
            workers_with_configs,
            true,
            2,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
351
352

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
353
        let decay_now = Instant::now();
354

355
356
357
358
359
360
        seq_manager_1.add_request(
            SequenceRequest {
                request_id: "request_0".to_string(),
                token_sequence: None,
                track_prefill_tokens: true,
                expected_output_tokens: None,
361
                prefill_load_hint: tracking_hint(12),
362
363
364
                worker: WorkerWithDpRank::from_worker_id(0),
                lora_name: None,
            },
365
            decay_now,
366
367
368
369
370
371
372
373
        )?;

        seq_manager_1.add_request(
            SequenceRequest {
                request_id: "request_1".to_string(),
                token_sequence: None,
                track_prefill_tokens: true,
                expected_output_tokens: None,
374
                prefill_load_hint: tracking_hint(8),
375
376
377
                worker: WorkerWithDpRank::from_worker_id(1),
                lora_name: None,
            },
378
            decay_now,
379
380
381
382
383
384
385
386
        )?;

        seq_manager_2.add_request(
            SequenceRequest {
                request_id: "request_2".to_string(),
                token_sequence: None,
                track_prefill_tokens: true,
                expected_output_tokens: None,
387
                prefill_load_hint: tracking_hint(16),
388
389
390
                worker: WorkerWithDpRank::from_worker_id(2),
                lora_name: None,
            },
391
            decay_now,
392
        )?;
393
394
395

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

396
        let tokens_phase1 = seq_manager_1.active_tokens(Instant::now());
397

Yan Ru Pei's avatar
Yan Ru Pei committed
398
399
400
401
        let worker_0 = WorkerWithDpRank::from_worker_id(0);
        let worker_1 = WorkerWithDpRank::from_worker_id(1);
        let worker_2 = WorkerWithDpRank::from_worker_id(2);

402
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
403
            tokens_phase1[&worker_0], 12,
404
            "Worker 0 should have 12 active tokens"
405
406
        );
        assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
407
408
409
410
411
            tokens_phase1[&worker_1], 8,
            "Worker 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_2], 16,
412
            "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
413
        );
414

415
416
        seq_manager_1.mark_prefill_completed(&"request_2".to_string(), Instant::now())?;
        seq_manager_1.free(&"request_2".to_string(), Instant::now())?;
417

418
419
420
421
        seq_manager_2.mark_prefill_completed(&"request_0".to_string(), Instant::now())?;
        seq_manager_2.mark_prefill_completed(&"request_1".to_string(), Instant::now())?;
        seq_manager_2.free(&"request_0".to_string(), Instant::now())?;
        seq_manager_2.free(&"request_1".to_string(), Instant::now())?;
422
423
424

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

425
        let tokens_phase2 = seq_manager_2.active_tokens(Instant::now());
426
427

        for worker_id in 0..=2 {
Yan Ru Pei's avatar
Yan Ru Pei committed
428
            let worker = WorkerWithDpRank::from_worker_id(worker_id);
429
            assert_eq!(
Yan Ru Pei's avatar
Yan Ru Pei committed
430
                tokens_phase2[&worker], 0,
431
432
433
434
435
                "Worker {} should have 0 active tokens after all requests freed",
                worker_id
            );
        }

436
        Ok(())
437
438
    }
}