// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Runtime-specific glue for [`ActiveSequencesMultiWorker`].
//!
//! This module provides the concrete [`SequencePublisher`] and [`SequenceSubscriber`]
//! implementations that wire the runtime-agnostic business logic (in `dynamo_kv_router`)
//! to NATS event transport and Prometheus metrics.

pub use dynamo_kv_router::multi_worker_sequence::{
    ActiveSequencesMultiWorker, SequenceError, SequencePublisher, SequenceRequest,
    SequenceSubscriber,
};
14
use dynamo_kv_router::protocols::{ActiveLoad, ActiveSequenceEvent, WorkerWithDpRank};
15
pub use dynamo_kv_router::sequence::{ActiveSequences, RequestId};
16

17
18
19
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
20
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
21
use std::collections::HashMap;
22
23
use std::sync::Arc;

24
use super::metrics::WORKER_LOAD_METRICS;
25
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
Yan Ru Pei's avatar
Yan Ru Pei committed
26
use crate::local_model::runtime_config::ModelRuntimeConfig;
27

28
29
/// Concrete [`SequencePublisher`] backed by NATS [`EventPublisher`] and Prometheus gauges.
pub struct RuntimeSequencePublisher {
30
    event_publisher: EventPublisher,
31
    metrics_publisher: Arc<EventPublisher>,
32
33
}

34
35
36
impl SequencePublisher for RuntimeSequencePublisher {
    async fn publish_event(&self, event: &ActiveSequenceEvent) -> anyhow::Result<()> {
        self.event_publisher.publish(event).await
37
38
    }

39
    fn publish_load(&self, load: ActiveLoad) {
40
41
        let publisher = self.metrics_publisher.clone();
        tokio::spawn(async move {
42
            if let Err(e) = publisher.publish(&load).await {
43
                tracing::trace!(
44
45
46
                    "Failed to publish ActiveLoad to NATS for worker (id={}, dp_rank={}): {e:?}",
                    load.worker_id,
                    load.dp_rank
47
48
49
                );
            }
        });
50
51
    }

52
    fn observe_load(
53
        &self,
54
55
56
57
58
59
60
61
62
63
64
65
        worker: &WorkerWithDpRank,
        worker_type: &str,
        blocks: usize,
        tokens: usize,
    ) {
        WORKER_LOAD_METRICS.observe(
            worker.worker_id,
            worker.dp_rank,
            worker_type,
            blocks,
            tokens,
        );
66
    }
67
}
68

69
70
71
72
/// Concrete [`SequenceSubscriber`] backed by NATS typed event stream.
pub struct RuntimeSequenceSubscriber {
    inner: dynamo_runtime::transports::event_plane::TypedEventSubscriber<ActiveSequenceEvent>,
}
73

74
75
76
77
78
impl SequenceSubscriber for RuntimeSequenceSubscriber {
    async fn next_event(&mut self) -> Option<anyhow::Result<ActiveSequenceEvent>> {
        match self.inner.next().await? {
            Ok((_envelope, event)) => Some(Ok(event)),
            Err(e) => Some(Err(e)),
79
        }
80
    }
81
}
82

83
84
/// Type alias for the runtime-wired multi-worker sequence tracker.
pub type ActiveSequencesMulti = ActiveSequencesMultiWorker<RuntimeSequencePublisher>;
85

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
/// Convenience async constructor that creates the NATS publishers/subscribers
/// and returns an `Arc<ActiveSequencesMulti>` with replica sync already running.
pub async fn create_multi_worker_sequences(
    component: Component,
    block_size: usize,
    workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
    replica_sync: bool,
    router_id: u64,
    worker_type: &'static str,
) -> Result<Arc<ActiveSequencesMulti>> {
    let event_publisher =
        EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
    let metrics_publisher =
        Arc::new(EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?);

    let publisher = RuntimeSequencePublisher {
        event_publisher,
        metrics_publisher,
    };

106
    let dp_range: HashMap<u64, (u32, u32)> = workers_with_configs
107
        .into_iter()
108
109
110
111
112
113
        .map(|(id, config)| {
            (
                id,
                (config.data_parallel_start_rank, config.data_parallel_size),
            )
        })
114
115
116
117
118
        .collect();

    let multi_worker = ActiveSequencesMultiWorker::new(
        publisher,
        block_size,
119
        dp_range,
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        replica_sync,
        router_id,
        worker_type,
    );

    let arc = Arc::new(multi_worker);

    if replica_sync {
        let subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
            .await?
            .typed::<ActiveSequenceEvent>();
        let subscriber = RuntimeSequenceSubscriber { inner: subscriber };
        let cancel_token = component.drt().runtime().child_token();
        arc.start_replica_sync(subscriber, cancel_token);
134
    }
135

136
137
138
    let expiry_cancel = component.drt().runtime().child_token();
    arc.start_periodic_force_expiry_across_all_workers(expiry_cancel);

139
    Ok(arc)
140
141
142
143
144
}

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_runtime::{DistributedRuntime, Runtime};

    /// Creates two replica-synced decode trackers on the same `component`, with router
    /// ids 1 and 2, so each test can check that one router observes the other's events.
    async fn create_replica_pair(
        component: Component,
        block_size: usize,
        workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
    ) -> Result<(Arc<ActiveSequencesMulti>, Arc<ActiveSequencesMulti>)> {
        let first = create_multi_worker_sequences(
            component.clone(),
            block_size,
            workers_with_configs.clone(),
            true,
            1,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
        let second = create_multi_worker_sequences(
            component,
            block_size,
            workers_with_configs,
            true,
            2,
            crate::discovery::WORKER_TYPE_DECODE,
        )
        .await?;
        Ok((first, second))
    }

    #[test]
    fn test_active_sequences_shared_blocks() {
        let mut seqs = ActiveSequences::new(4);

        // Three distinct blocks, 12 tokens.
        seqs.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (3, 12));

        // One more distinct block adds 4 tokens.
        seqs.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (4, 16));

        // Every block of request_3 is already active and overlap covers all 4 blocks,
        // so neither count moves.
        seqs.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (4, 16));

        // Block 4 is still held by request_3, so only the token count drops.
        seqs.free(&"request_2".to_string());
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (4, 12));

        seqs.free(&"request_3".to_string());
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (3, 12));

        seqs.free(&"request_1".to_string());
        assert_eq!((seqs.active_blocks(), seqs.active_tokens()), (0, 0));
    }

    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_cross_instance_sync() -> Result<()> {
        dynamo_runtime::logging::init();

        let block_size = 4;
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        let namespace = distributed.namespace("test_cross_instance_sync")?;
        let component = namespace.component("sequences")?;

        // Worker 0 is configured with two data-parallel ranks; worker 1 uses a fresh config.
        let mut dp2_config = ModelRuntimeConfig::new();
        dp2_config.data_parallel_size = 2;
        let workers_with_configs: HashMap<u64, ModelRuntimeConfig> =
            HashMap::from([(0, dp2_config), (1, ModelRuntimeConfig::new())]);

        let (router_a, router_b) =
            create_replica_pair(component, block_size, workers_with_configs).await?;

        // Presumably gives the replica-sync subscriptions time to settle before publishing.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Phase 1: router_a adds work on both DP ranks of worker 0, and router_b adds work
        // on worker 1.
        let request_0 = SequenceRequest {
            request_id: "request_0".to_string(),
            token_sequence: Some(vec![0, 1, 2]),
            isl: 12,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::new(0, 0),
            lora_name: None,
        };
        router_a.add_request(request_0).await?;

        let request_1 = SequenceRequest {
            request_id: "request_1".to_string(),
            token_sequence: Some(vec![3, 4]),
            isl: 8,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::new(0, 1),
            lora_name: None,
        };
        router_a.add_request(request_1).await?;

        let request_2 = SequenceRequest {
            request_id: "request_2".to_string(),
            token_sequence: Some(vec![0, 1, 2, 3]),
            isl: 16,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::new(1, 0),
            lora_name: None,
        };
        router_b.add_request(request_2).await?;

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // router_a must also reflect request_2, which only router_b added.
        let blocks = router_a.active_blocks();
        let tokens = router_a.active_tokens();

        let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
        let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
        let worker_1_dp0 = WorkerWithDpRank::new(1, 0);

        assert_eq!(
            blocks[&worker_0_dp0], 3,
            "Worker 0 dp_rank 0 should have 3 active blocks (from request_0)"
        );
        assert_eq!(
            blocks[&worker_0_dp1], 2,
            "Worker 0 dp_rank 1 should have 2 active blocks (from request_1)"
        );
        assert_eq!(
            blocks[&worker_1_dp0], 4,
            "Worker 1 dp_rank 0 should have 4 active blocks (from request_2 added by seq_manager_2)"
        );
        assert_eq!(
            tokens[&worker_0_dp0], 12,
            "Worker 0 dp_rank 0 should have 12 active tokens"
        );
        assert_eq!(
            tokens[&worker_0_dp1], 8,
            "Worker 0 dp_rank 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens[&worker_1_dp0], 16,
            "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // Phase 2: each router frees the requests the other router created.
        router_a.free(&"request_2".to_string()).await?;
        router_b.free(&"request_0".to_string()).await?;
        router_b.free(&"request_1".to_string()).await?;

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        let blocks = router_b.active_blocks();
        let tokens = router_b.active_tokens();

        for worker in [
            WorkerWithDpRank::new(0, 0),
            WorkerWithDpRank::new(0, 1),
            WorkerWithDpRank::new(1, 0),
        ] {
            assert_eq!(
                blocks[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active blocks after all requests freed",
                worker.worker_id, worker.dp_rank
            );
            assert_eq!(
                tokens[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active tokens after all requests freed",
                worker.worker_id, worker.dp_rank
            );
        }

        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
        dynamo_runtime::logging::init();

        let block_size = 4;
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        let namespace = distributed.namespace("test_no_token_seq_sync")?;
        let component = namespace.component("sequences")?;

        // Three workers, each with a freshly constructed config.
        let workers_with_configs: HashMap<u64, ModelRuntimeConfig> = (0..=2u64)
            .map(|worker_id| (worker_id, ModelRuntimeConfig::new()))
            .collect();

        let (router_a, router_b) =
            create_replica_pair(component, block_size, workers_with_configs).await?;

        // Presumably gives the replica-sync subscriptions time to settle before publishing.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Phase 1: requests carry no token sequence, so only token counts are tracked.
        let request_0 = SequenceRequest {
            request_id: "request_0".to_string(),
            token_sequence: None,
            isl: 12,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::from_worker_id(0),
            lora_name: None,
        };
        router_a.add_request(request_0).await?;

        let request_1 = SequenceRequest {
            request_id: "request_1".to_string(),
            token_sequence: None,
            isl: 8,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::from_worker_id(1),
            lora_name: None,
        };
        router_a.add_request(request_1).await?;

        let request_2 = SequenceRequest {
            request_id: "request_2".to_string(),
            token_sequence: None,
            isl: 16,
            overlap: 0,
            track_prefill_tokens: true,
            expected_output_tokens: None,
            worker: WorkerWithDpRank::from_worker_id(2),
            lora_name: None,
        };
        router_b.add_request(request_2).await?;

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        let tokens = router_a.active_tokens();

        let worker_0 = WorkerWithDpRank::from_worker_id(0);
        let worker_1 = WorkerWithDpRank::from_worker_id(1);
        let worker_2 = WorkerWithDpRank::from_worker_id(2);

        assert_eq!(
            tokens[&worker_0], 12,
            "Worker 0 should have 12 active tokens"
        );
        assert_eq!(
            tokens[&worker_1], 8,
            "Worker 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens[&worker_2], 16,
            "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // Phase 2: complete prefill, then free each request on the router that did not
        // create it.
        router_a
            .mark_prefill_completed(&"request_2".to_string())
            .await?;
        router_a.free(&"request_2".to_string()).await?;

        router_b
            .mark_prefill_completed(&"request_0".to_string())
            .await?;
        router_b
            .mark_prefill_completed(&"request_1".to_string())
            .await?;
        router_b.free(&"request_0".to_string()).await?;
        router_b.free(&"request_1".to_string()).await?;

        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        let tokens = router_b.active_tokens();

        for worker_id in 0..=2 {
            let worker = WorkerWithDpRank::from_worker_id(worker_id);
            assert_eq!(
                tokens[&worker], 0,
                "Worker {} should have 0 active tokens after all requests freed",
                worker_id
            );
        }

        Ok(())
    }
}