lib.rs 34.7 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use dynamo_llm::local_model::LocalModel;
5
use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
6
use dynamo_runtime::storage::kv;
7
8
use futures::StreamExt;
use once_cell::sync::OnceCell;
9
use pyo3::IntoPyObjectExt;
10
use pyo3::exceptions::PyStopAsyncIteration;
Richard Huo's avatar
Richard Huo committed
11
use pyo3::types::PyCapsule;
12
use pyo3::types::{PyDict, PyString};
13
14
use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress;
Richard Huo's avatar
Richard Huo committed
15
use std::ffi::CString;
16
use std::fs;
17
use std::path::PathBuf;
Richard Huo's avatar
Richard Huo committed
18
19
20
21
use std::{
    fmt::Display,
    sync::{Arc, Weak},
};
22
use tokio::sync::Mutex;
23
use tracing::Instrument;
24

25
use dynamo_runtime::config::environment_names::logging::otlp as env_otlp;
Neelay Shah's avatar
Neelay Shah committed
26
use dynamo_runtime::{
Ryan Olson's avatar
Ryan Olson committed
27
    self as rs, logging,
28
    pipeline::{
29
        AsyncEngineContextProvider, EngineStream, ManyOut, SingleIn, context::Context as RsContext,
30
        network::egress::push_router::RouterMode as RsRouterMode,
31
    },
32
    protocols::annotated::Annotated as RsAnnotated,
33
    traits::DistributedRuntimeProvider,
34
35
};

Neelay Shah's avatar
Neelay Shah committed
36
use dynamo_llm::{self as llm_rs};
37
38
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};

39
use crate::llm::local_model::ModelRuntimeConfig;
40
use crate::llm::preprocessor::{MediaDecoder, MediaFetcher};
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
    RoundRobin,
    Random,
    KV,
}

impl From<RouterMode> for RsRouterMode {
    fn from(mode: RouterMode) -> Self {
        match mode {
            RouterMode::RoundRobin => Self::RoundRobin,
            RouterMode::Random => Self::Random,
            RouterMode::KV => Self::KV,
        }
    }
}
59

60
mod context;
61
mod engine;
62
mod http;
63
mod kserve_grpc;
64
mod llm;
65
mod parsers;
66
mod planner;
67
mod prometheus_metrics;
68
69
70
71
72
73
74
75

/// Default for the `annotated` keyword argument of the `Client` request methods.
const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);

/// Ingress used by `Endpoint::serve_endpoint`: a single JSON value in, and a
/// stream of `Annotated` JSON values out.
type JsonServerStreamingIngress =
    Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>;

/// Guards the one-time `pyo3_async_runtimes` tokio bridge initialisation that
/// `DistributedRuntime::new` performs when it creates a fresh Worker.
static INIT: OnceCell<()> = OnceCell::new();

76
77
// Helper to get appropriate span for instrumentation - always emit spans
fn get_span_for_context(context: &context::Context, operation: &str) -> tracing::Span {
78
79
80
81
82
83
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        None,
    )
84
85
86
87
88
89
90
91
}

// Helper to create span for direct method with instance_id
fn get_span_for_direct_context(
    context: &context::Context,
    operation: &str,
    instance_id: &str,
) -> tracing::Span {
92
93
94
95
96
97
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        Some(instance_id),
    )
98
99
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/// Wrap `request` in a runtime context, linking it under `parent_ctx` when one is given.
///
/// With a parent, the new context reuses the parent's id and is registered as
/// its child. If the parent is already stopped or killed, the child is told to
/// stop generating immediately. Without a parent, the request is converted into
/// a standalone context.
fn create_request_context(
    request: serde_json::Value,
    parent_ctx: &Option<context::Context>,
) -> RsContext<serde_json::Value> {
    let Some(parent) = parent_ctx.as_ref() else {
        return request.into();
    };

    let parent_inner = parent.inner();
    let child_ctx = RsContext::with_id(request, parent_inner.id().to_string());
    parent_inner.link_child(child_ctx.context());

    let parent_already_cancelled = parent_inner.is_stopped() || parent_inner.is_killed();
    if parent_already_cancelled {
        // Let the server handle the cancellation for now since not all backends are
        // properly handling request exceptions.
        // TODO: (DIS-830) Return an error if context is cancelled
        child_ctx.context().stop_generating();
    }

    child_ctx
}

123
124
125
126
127
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
128
    // Initialize logging early unless OTEL export is enabled (which requires tokio runtime)
129
130
131
132
    if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
        .map(|v| v == "1")
        .unwrap_or(false)
    {
133
        eprintln!(
134
            "Warning: OTEL_EXPORT_ENABLED detected. Logging initialization deferred until runtime is available. Early logs may be dropped."
135
136
137
138
139
        );
    } else {
        rs::logging::init();
    }

Yan Ru Pei's avatar
Yan Ru Pei committed
140
    m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
141
    m.add_function(wrap_pyfunction!(lora_name_to_id, m)?)?;
Ryan Olson's avatar
Ryan Olson committed
142
    m.add_function(wrap_pyfunction!(log_message, m)?)?;
143
    m.add_function(wrap_pyfunction!(register_llm, m)?)?;
144
    m.add_function(wrap_pyfunction!(unregister_llm, m)?)?;
145
    m.add_function(wrap_pyfunction!(fetch_llm, m)?)?;
146
147
    m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
    m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
148

149
150
151
152
153
154
155
    m.add_class::<DistributedRuntime>()?;
    m.add_class::<CancellationToken>()?;
    m.add_class::<Namespace>()?;
    m.add_class::<Component>()?;
    m.add_class::<Endpoint>()?;
    m.add_class::<Client>()?;
    m.add_class::<AsyncResponseStream>()?;
156
157
158
    m.add_class::<llm::entrypoint::EntrypointArgs>()?;
    m.add_class::<llm::entrypoint::EngineConfig>()?;
    m.add_class::<llm::entrypoint::EngineType>()?;
159
160
    m.add_class::<llm::entrypoint::RouterConfig>()?;
    m.add_class::<llm::entrypoint::KvRouterConfig>()?;
161
    m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
162
    m.add_class::<llm::model_card::ModelDeploymentCard>()?;
163
    m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
164
    m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
165
166
    m.add_class::<llm::preprocessor::MediaDecoder>()?;
    m.add_class::<llm::preprocessor::MediaFetcher>()?;
167
    m.add_class::<llm::backend::Backend>()?;
168
169
    m.add_class::<llm::kv::OverlapScores>()?;
    m.add_class::<llm::kv::KvIndexer>()?;
170
    m.add_class::<llm::kv::ApproxKvIndexer>()?;
171
    m.add_class::<llm::kv::KvEventPublisher>()?;
Yan Ru Pei's avatar
Yan Ru Pei committed
172
173
    m.add_class::<llm::kv::RadixTree>()?;
    m.add_class::<llm::kv::ZmqKvEventListener>()?;
174
175
    m.add_class::<llm::kv::ZmqKvEventPublisher>()?;
    m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
176
    m.add_class::<llm::kv::KvRecorder>()?;
177
    m.add_class::<llm::lora::LoRADownloader>()?;
178
179
    m.add_class::<http::HttpService>()?;
    m.add_class::<http::HttpAsyncEngine>()?;
180
    m.add_class::<context::Context>()?;
181
    m.add_class::<ModelType>()?;
182
    m.add_class::<ModelInput>()?;
183
184
185
186
    m.add_class::<llm::kv::ForwardPassMetrics>()?;
    m.add_class::<llm::kv::WorkerStats>()?;
    m.add_class::<llm::kv::KvStats>()?;
    m.add_class::<llm::kv::SpecDecodeStats>()?;
187
188
    m.add_class::<llm::kv::KvPushRouter>()?;
    m.add_class::<llm::kv::KvPushRouterStream>()?;
189
    m.add_class::<RouterMode>()?;
190
    m.add_class::<kserve_grpc::KserveGrpcService>()?;
191
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
192
193
194
    m.add_class::<planner::VirtualConnectorCoordinator>()?;
    m.add_class::<planner::VirtualConnectorClient>()?;
    m.add_class::<planner::PlannerDecision>()?;
195
196

    engine::add_to_module(m)?;
197
    parsers::add_to_module(m)?;
198

199
    m.add_class::<prometheus_metrics::RuntimeMetrics>()?;
200
201
202
203
    let prometheus_metrics = PyModule::new(m.py(), "prometheus_metrics")?;
    prometheus_metrics::add_to_module(&prometheus_metrics)?;
    m.add_submodule(&prometheus_metrics)?;

204
205
206
207
208
209
210
211
212
213
    Ok(())
}

/// Convert any displayable error into a generic Python `Exception` whose
/// message is the error's `Display` text.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    let message = err.to_string();
    PyException::new_err(message)
}

Ryan Olson's avatar
Ryan Olson committed
214
215
216
217
218
219
220
/// Log a message from Python with file and line info.
///
/// Forwards `level`, `message`, `module`, `file` and `line` unchanged to the
/// runtime's `logging::log_message`.
#[pyfunction]
#[pyo3(text_signature = "(level, message, module, file, line)")]
fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) {
    logging::log_message(level, message, module, file, line)
}

221
222
223
224
225
226
227
/// Generate a deterministic signed int32 ID from a LoRA name using blake3 hash.
///
/// Thin Python-facing wrapper over `dynamo_llm::utils::lora_name_to_id`.
#[pyfunction]
#[pyo3(text_signature = "(lora_name)")]
fn lora_name_to_id(lora_name: &str) -> i32 {
    let id: i32 = llm_rs::utils::lora_name_to_id(lora_name);
    id
}

228
229
/// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend.
230
231
232
233
234
235
236
///
/// If `lora_name` is provided, this function will publish a LoRA adapter instead of a base model:
/// - LoRA path: v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
/// - Base model path: v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}
///
/// For LoRA mode, both `lora_name` and `base_model_path` must be provided together.
/// Providing only one of them will result in an error.
237
#[pyfunction]
238
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None, lora_name=None, base_model_path=None))]
239
#[allow(clippy::too_many_arguments)]
240
241
fn register_llm<'p>(
    py: Python<'p>,
242
    model_input: ModelInput,
243
    model_type: ModelType,
244
245
246
    endpoint: Endpoint,
    model_path: &str,
    model_name: Option<&str>,
247
248
    context_length: Option<u32>,
    kv_cache_block_size: Option<u32>,
249
    router_mode: Option<RouterMode>,
250
    migration_limit: u32,
251
    runtime_config: Option<ModelRuntimeConfig>,
252
    user_data: Option<&Bound<'p, PyDict>>,
253
    custom_template_path: Option<&str>,
254
255
    media_decoder: Option<MediaDecoder>,
    media_fetcher: Option<MediaFetcher>,
256
257
    lora_name: Option<&str>,
    base_model_path: Option<&str>,
258
) -> PyResult<Bound<'p, PyAny>> {
259
260
261
262
263
264
265
266
267
268
269
270
271
272
    // Validate Prefill model type requirements
    if model_type.inner == llm_rs::model_type::ModelType::Prefill {
        if !matches!(model_input, ModelInput::Tokens) {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires model_input to be ModelInput::Tokens",
            ));
        }
        if migration_limit != 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires migration_limit to be 0",
            ));
        }
    }

273
274
275
    let model_input = match model_input {
        ModelInput::Text => llm_rs::model_type::ModelInput::Text,
        ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
276
        ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
277
278
    };

279
280
    let model_type_obj = model_type.inner;

281
    let inner_path = model_path.to_string();
282
    let model_name = model_name.map(|n| n.to_string());
283
284
285
    let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
    let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());

286
287
288
289
290
291
292
293
294
295
296
297
298
    // Early validation of custom template path
    let custom_template_path_owned = custom_template_path
        .map(|s| {
            let path = PathBuf::from(s);
            if !path.exists() {
                return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
                    format!("Custom template file does not exist: {}", path.display()),
                ));
            }
            Ok(path)
        })
        .transpose()?;

299
300
301
302
303
304
305
    let user_data_json = user_data
        .map(|dict| pythonize::depythonize(dict))
        .transpose()
        .map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
        })?;

306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
    // Validate LoRA parameters: both or neither must be provided
    if lora_name.is_some() ^ base_model_path.is_some() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "lora_name and base_model_path must both be provided together, or neither",
        ));
    }

    // Determine source_path and lora_identifier based on registration mode
    let (source_path, lora_identifier) = match (lora_name, base_model_path) {
        (Some(lora), Some(base)) => (base.to_string(), Some(lora.to_string())),
        _ => (inner_path, None),
    };

    // Model name: use lora name if present, otherwise provided name or default to source path
    let model_name = lora_identifier
        .clone()
        .or(model_name)
        .or_else(|| Some(source_path.clone()));

325
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
326
327
328
        // Resolve the model path (local or fetch from HuggingFace)
        let model_path = if fs::exists(&source_path)? {
            PathBuf::from(&source_path)
329
        } else {
330
            LocalModel::fetch(&source_path, false)
331
332
333
334
                .await
                .map_err(to_pyerr)?
        };

335
336
        let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
        builder
337
            .model_path(model_path)
338
            .model_name(model_name.clone())
339
            .context_length(context_length)
340
            .kv_cache_block_size(kv_cache_block_size)
341
            .router_config(Some(router_config))
342
            .migration_limit(Some(migration_limit))
343
            .runtime_config(runtime_config.unwrap_or_default().inner)
344
            .user_data(user_data_json)
345
346
347
            .custom_template_path(custom_template_path_owned)
            .media_decoder(media_decoder.map(|m| m.inner))
            .media_fetcher(media_fetcher.map(|m| m.inner));
348

349
        let mut local_model = builder.build().await.map_err(to_pyerr)?;
350
        local_model
351
352
353
354
355
356
            .attach(
                &endpoint.inner,
                model_type_obj,
                model_input,
                lora_identifier.as_deref(),
            )
357
358
359
            .await
            .map_err(to_pyerr)?;

360
361
362
363
364
365
        if let Some(lora_name) = lora_identifier {
            tracing::info!("Registered LoRA '{}' MDC", lora_name);
        } else {
            tracing::info!("Registered base model '{:?}' MDC", model_name);
        }

366
367
368
369
        Ok(())
    })
}

370
371
372
373
374
375
376
377
378
379
380
381
382
/// Unregister a Model Deployment Card (MDC) from the service registry
///
/// This removes an LLM deployment from the discovery system.
///
/// # Arguments
///
/// * `endpoint` - The endpoint where the model is registered
/// * `lora_name` - Optional LoRA adapter name (if unregistering a LoRA deployment)
///
/// # MDC Path Format
///
/// - Base model: `v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}`
/// - LoRA model: `v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}`
383
#[pyfunction]
384
385
386
387
388
389
390
391
#[pyo3(signature = (endpoint, lora_name=None))]
fn unregister_llm<'p>(
    py: Python<'p>,
    endpoint: Endpoint,
    lora_name: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> {
    let lora_name_owned = lora_name.map(|s| s.to_string());

392
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
393
394
        // Unified detach method handles both base models and LoRA adapters
        LocalModel::detach_from_endpoint(&endpoint.inner, lora_name_owned.as_deref())
395
396
397
398
399
400
            .await
            .map_err(to_pyerr)?;
        Ok(())
    })
}

401
402
403
404
405
406
407
408
409
410
411
/// Download a model from Hugging Face, returning its local path.
///
/// Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
#[pyfunction]
#[pyo3(signature = (remote_name))]
fn fetch_llm<'p>(py: Python<'p>, remote_name: &str) -> PyResult<Bound<'p, PyAny>> {
    let repo_id: String = remote_name.to_owned();
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let fetched = LocalModel::fetch(&repo_id, false).await;
        fetched.map_err(to_pyerr)
    })
}

412
413
#[pyclass]
#[derive(Clone)]
Ryan Olson's avatar
Ryan Olson committed
414
pub struct DistributedRuntime {
415
416
417
418
    inner: rs::DistributedRuntime,
    event_loop: PyObject,
}

Ryan Olson's avatar
Ryan Olson committed
419
420
impl DistributedRuntime {
    #[allow(dead_code)]
421
    pub(crate) fn inner(&self) -> &rs::DistributedRuntime {
Ryan Olson's avatar
Ryan Olson committed
422
423
424
425
        &self.inner
    }
}

426
#[pyclass]
427
#[derive(Clone)]
428
429
430
431
432
struct CancellationToken {
    inner: rs::CancellationToken,
}

#[pyclass]
433
#[derive(Clone)]
434
435
436
437
438
439
struct Namespace {
    inner: rs::component::Namespace,
    event_loop: PyObject,
}

#[pyclass]
440
#[derive(Clone)]
441
442
443
444
445
446
struct Component {
    inner: rs::component::Component,
    event_loop: PyObject,
}

#[pyclass]
447
#[derive(Clone)]
448
449
450
451
452
453
struct Endpoint {
    inner: rs::component::Endpoint,
    event_loop: PyObject,
}

#[pyclass]
454
#[derive(Clone)]
455
struct Client {
456
    router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
457
458
}

459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
/// Python wrapper around `dynamo_llm::model_type::ModelType`.
///
/// Values can be combined with `|` (see `__or__`).
#[pyclass]
#[derive(Clone, PartialEq)]
struct ModelType {
    /// The wrapped model-type value.
    inner: llm_rs::model_type::ModelType,
}

#[pymethods]
// Class attributes deliberately use the same CamelCase names as the Rust variants.
#[allow(non_upper_case_globals)]
impl ModelType {
    #[classattr]
    const Chat: Self = Self {
        inner: llm_rs::model_type::ModelType::Chat,
    };

    #[classattr]
    const Completions: Self = Self {
        inner: llm_rs::model_type::ModelType::Completions,
    };

    #[classattr]
    const Embedding: Self = Self {
        inner: llm_rs::model_type::ModelType::Embedding,
    };

    #[classattr]
    const TensorBased: Self = Self {
        inner: llm_rs::model_type::ModelType::TensorBased,
    };

    #[classattr]
    const Prefill: Self = Self {
        inner: llm_rs::model_type::ModelType::Prefill,
    };

    /// `a | b`: bitwise-or of the two wrapped values.
    fn __or__(&self, other: &Self) -> Self {
        let combined = self.inner | other.inner;
        Self { inner: combined }
    }

    /// String form taken from the wrapped value's `Display` implementation.
    fn __str__(&self) -> String {
        let ModelType { inner } = self;
        inner.to_string()
    }
}

500
501
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)]
502
503
504
enum ModelInput {
    Text = 1,
    Tokens = 2,
505
    Tensor = 3,
506
507
}

508
509
510
#[pymethods]
impl DistributedRuntime {
    #[new]
511
    fn new(event_loop: PyObject, store_kv: String, request_plane: String) -> PyResult<Self> {
512
        let selected_kv_store: kv::Selector = store_kv.parse().map_err(to_pyerr)?;
513
        let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?;
514

515
516
517
        // Try to get existing runtime first, create new Worker only if needed
        // This allows multiple DistributedRuntime instances to share the same tokio runtime
        let runtime = rs::Worker::runtime_from_existing()
518
            .or_else(|_| -> anyhow::Result<rs::Runtime> {
519
520
521
522
                // No existing Worker, create new one
                let worker = rs::Worker::from_settings()?;

                // Initialize pyo3 bridge (only happens once per process)
523
                INIT.get_or_try_init(|| -> anyhow::Result<()> {
524
525
                    let primary = worker.tokio_runtime()?;
                    pyo3_async_runtimes::tokio::init_with_runtime(primary).map_err(|e| {
526
                        anyhow::anyhow!("failed to initialize pyo3 static runtime: {:?}", e)
527
                    })?;
528
                    Ok(())
529
530
                })?;

531
                Ok(worker.runtime().clone())
532
533
            })
            .map_err(to_pyerr)?;
534

535
536
        // Initialize logging in context where tokio runtime is available
        // otel exporter requires it
537
538
539
540
        if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
            .map(|v| v == "1")
            .unwrap_or(false)
        {
541
542
543
544
            runtime.secondary().block_on(async {
                rs::logging::init();
            });
        }
545

546
547
        let runtime_config = DistributedConfig {
            store_backend: selected_kv_store,
548
549
550
551
552
553
            // We only need NATS here to monitor it's metrics, so only if it's our request plane.
            nats_config: if request_plane.is_nats() {
                Some(dynamo_runtime::transports::nats::ClientOptions::default())
            } else {
                None
            },
554
            request_plane,
555
556
557
558
559
        };
        let inner = runtime
            .secondary()
            .block_on(rs::DistributedRuntime::new(runtime, runtime_config))
            .map_err(to_pyerr)?;
560
561
562
563

        Ok(DistributedRuntime { inner, event_loop })
    }

Ryan Olson's avatar
Ryan Olson committed
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
    #[staticmethod]
    fn detached(py: Python) -> PyResult<Self> {
        let rt = rs::Worker::runtime_from_existing().map_err(to_pyerr)?;
        let handle = rt.primary();

        let inner = handle
            .block_on(rs::DistributedRuntime::from_settings(rt))
            .map_err(to_pyerr)?;

        Ok(DistributedRuntime {
            inner,
            event_loop: py.None(),
        })
    }

579
580
581
582
583
584
585
586
    fn namespace(&self, name: String) -> PyResult<Namespace> {
        Ok(Namespace {
            inner: self.inner.namespace(name).map_err(to_pyerr)?,
            event_loop: self.event_loop.clone(),
        })
    }

    fn shutdown(&self) {
587
        self.inner.shutdown();
588
589
590
591
592
    }

    fn event_loop(&self) -> PyObject {
        self.event_loop.clone()
    }
593
594
595
596
597

    fn child_token(&self) -> CancellationToken {
        let inner = self.inner.runtime().child_token();
        CancellationToken { inner }
    }
Richard Huo's avatar
Richard Huo committed
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612

    // This is used to pass the DistributedRuntime from the dynamo-runtime bindings
    // to the KVBM bindings, since KVBM cannot directly use the struct from this cdylib.
    // TODO: Create a separate crate "dynamo-python" so that all binding crates can import
    // from it and share the same crate path. This will allow PyO3 to automatically
    // recognize that both bindings use the same PyClass.
    #[pyo3(name = "to_capsule")]
    fn to_capsule<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
        let arc: Arc<rs::DistributedRuntime> = Arc::new(self.inner.clone());
        let weak: Weak<rs::DistributedRuntime> = Arc::downgrade(&arc);

        let name = CString::new("dynamo.runtime.weak").expect("valid capsule name");

        PyCapsule::new(py, weak, Some(name))
    }
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
}

#[pymethods]
impl CancellationToken {
    /// Cancel this token.
    fn cancel(&self) {
        self.inner.cancel()
    }

    /// Awaitable that resolves to `None` once this token has been cancelled.
    fn cancelled<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let waiter = self.inner.clone();
        let until_cancelled = async move {
            waiter.cancelled().await;
            Ok(())
        };
        pyo3_async_runtimes::tokio::future_into_py(py, until_cancelled)
    }
}

#[pymethods]
impl Component {
    fn endpoint(&self, name: String) -> PyResult<Endpoint> {
        let inner = self.inner.endpoint(name);
        Ok(Endpoint {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }

640
641
    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
642
643
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_component(self.inner.clone())
644
    }
645
646
647
648
}

#[pymethods]
impl Endpoint {
649
    #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None))]
650
651
652
653
    fn serve_endpoint<'p>(
        &self,
        py: Python<'p>,
        generator: PyObject,
654
        graceful_shutdown: Option<bool>,
655
        metrics_labels: Option<Vec<(String, String)>>,
656
        health_check_payload: Option<&Bound<'p, PyDict>>,
657
658
659
660
661
    ) -> PyResult<Bound<'p, PyAny>> {
        let engine = Arc::new(engine::PythonAsyncEngine::new(
            generator,
            self.event_loop.clone(),
        )?);
662
        let ingress = JsonServerStreamingIngress::for_engine(engine.clone()).map_err(to_pyerr)?;
663
664
665
666
667
668
669
670
671
672
673
674
675

        // Convert Python dict to serde_json::Value if provided and validate it's an object
        let health_payload_json = health_check_payload
            .map(|dict| pythonize::depythonize::<serde_json::Value>(dict))
            .transpose()
            .map_err(|err| {
                pyo3::exceptions::PyTypeError::new_err(format!(
                    "Failed to convert health_check_payload: {}",
                    err
                ))
            })?;

        // Require an object/dict
676
677
678
679
680
681
        if let Some(ref payload) = health_payload_json
            && !payload.is_object()
        {
            return Err(pyo3::exceptions::PyTypeError::new_err(
                "health_check_payload must be a JSON object (dict)",
            ));
682
683
684
        }

        let mut builder = self
685
686
687
688
            .inner
            .endpoint_builder()
            .metrics_labels(metrics_labels)
            .handler(ingress);
689
690
691
692
693

        if let Some(payload) = health_payload_json {
            builder = builder.health_check_payload(payload);
        }

694
695
696
        // Register the engine in the local endpoint registry for in-process calls
        builder = builder.register_local_engine(engine).map_err(to_pyerr)?;

697
        let graceful_shutdown = graceful_shutdown.unwrap_or(true);
698
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
699
700
701
702
703
            builder
                .graceful_shutdown(graceful_shutdown)
                .start()
                .await
                .map_err(to_pyerr)?;
704
705
706
707
            Ok(())
        })
    }

708
709
710
    fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
711
            let client = inner.client().await.map_err(to_pyerr)?;
712
713
714
715
716
717
            let push_router = rs::pipeline::PushRouter::<
                serde_json::Value,
                RsAnnotated<serde_json::Value>,
            >::from_client(client, Default::default())
            .await
            .map_err(to_pyerr)?;
718
719
720
            Ok(Client {
                router: push_router,
            })
721
722
        })
    }
723

724
725
726
    // Opaque unique ID for this worker. May change over worker lifetime.
    fn connection_id(&self) -> u64 {
        self.inner.drt().connection_id()
727
    }
728
729
730

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
731
732
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_endpoint(self.inner.clone())
733
    }
734
735
736
737
738
739
740
741
742
743
744
}

#[pymethods]
impl Namespace {
    fn component(&self, name: String) -> PyResult<Component> {
        let inner = self.inner.component(name).map_err(to_pyerr)?;
        Ok(Component {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }
745
746
747

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
748
749
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_namespace(self.inner.clone())
750
    }
751
752
753
754
}

#[pymethods]
impl Client {
755
756
    /// Get list of current instances.
    /// Replaces endpoint_ids.
757
    fn instance_ids(&self) -> Vec<u64> {
758
        self.router.client.instance_ids()
759
760
    }

761
762
763
    /// Wait for an instance to be available for work.
    /// Replaces wait_for_endpoints.
    fn wait_for_instances<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
764
        let inner = self.router.client.clone();
765
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
766
            inner
767
                .wait_for_instances()
768
                .await
769
                .map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<u64>>())
770
                .map_err(to_pyerr)
771
772
773
774
        })
    }

    /// Issue a request to the endpoint using the default routing strategy.
775
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
776
777
778
779
780
    fn generate<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
781
        context: Option<context::Context>,
782
    ) -> PyResult<Bound<'p, PyAny>> {
783
        self.random(py, request, annotated, context)
784
785
786
    }

    /// Send a request to the next endpoint in a round-robin fashion.
787
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
788
789
790
791
792
    fn round_robin<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
793
        context: Option<context::Context>,
794
795
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
796
        let request_ctx = create_request_context(request, &context);
797
798
799
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
800
        let client = self.router.clone();
801
802

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
803
804
            let stream = match context {
                Some(context) => {
805
806
                    // Always instrument with appropriate span (none if no trace context)
                    let span = get_span_for_context(&context, "round_robin");
807
808
809
810
811
                    client
                        .round_robin(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
812
                }
813
                _ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
814
            };
815
816
817
818
819
820
821
822
823
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Send a request to a random endpoint.
824
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
825
826
827
828
829
    fn random<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
830
        context: Option<context::Context>,
831
832
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
833
        let request_ctx = create_request_context(request, &context);
834
835
836
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
837
        let client = self.router.clone();
838
839

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
840
841
            let stream = match context {
                Some(context) => {
842
                    // Always instrument with appropriate span (none if no trace context)
843
                    let span = get_span_for_context(&context, "random");
844
845
846
847
848
                    client
                        .random(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
849
                }
850
                _ => client.random(request_ctx).await.map_err(to_pyerr)?,
851
            };
852
853
854
855
856
857
858
859
860
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Directly send a request to a specific endpoint.
861
    #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
862
863
864
865
    fn direct<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
866
        instance_id: u64,
867
        annotated: Option<bool>,
868
        context: Option<context::Context>,
869
870
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
871
        let request_ctx = create_request_context(request, &context);
872
873
874
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
875
        let client = self.router.clone();
876
877

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
878
879
            let stream = match context {
                Some(context) => {
880
                    // Always instrument with appropriate span (none if no trace context)
881
882
                    let span =
                        get_span_for_direct_context(&context, "direct", &instance_id.to_string());
883
884
885
886
887
                    client
                        .direct(request_ctx, instance_id)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
888
889
                }
                _ => client
890
                    .direct(request_ctx, instance_id)
891
892
893
                    .await
                    .map_err(to_pyerr)?,
            };
894
895
896
897
898
899
900
901
902
903
904
905

            tokio::spawn(process_stream(stream, tx));

            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }
}

async fn process_stream(
906
    stream: EngineStream<RsAnnotated<serde_json::Value>>,
907
908
909
910
911
    tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
) {
    let mut stream = stream;
    while let Some(response) = stream.next().await {
        // Convert the response to a PyObject using Python's GIL
912
        let annotated: RsAnnotated<serde_json::Value> = response;
913
        let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
914
            Python::with_gil(|py| match pythonize::pythonize(py, &data) {
915
916
                Ok(pyobj) => Ok(pyobj.into()),
                Err(e) => Err(e.to_string()),
917
            })
918
919
920
921
922
923
        });

        let is_error = annotated.is_error();

        // Send the PyObject through the channel or log an error
        if let Err(e) = tx.send(annotated).await {
924
            tracing::error!("Failed to send response: {:?}", e);
925
            break;
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
        }

        if is_error {
            break;
        }
    }
}

/// Python async iterator over responses received from a background task.
///
/// Items are read from `rx`, which `process_stream` feeds. With `annotated`
/// set, each item is yielded wrapped in an `Annotated` object. Otherwise only
/// its `data` payload is yielded, and items without data are skipped.
#[pyclass]
struct AsyncResponseStream {
    /// Receiving end of the response channel; each `__anext__` future clones
    /// this handle and locks it while waiting for the next item.
    rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>,
    /// Yield full `Annotated` wrappers (true) or bare data payloads (false).
    annotated: bool,
}

#[pymethods]
impl AsyncResponseStream {
    /// This method is required to implement the `AsyncIterator` protocol.
    #[pyo3(name = "__aiter__")]
    fn aiter(slf: PyRef<Self>, py: Python) -> PyResult<Py<PyAny>> {
        slf.into_py_any(py)
    }

    /// This method is required to implement the `AsyncIterator` protocol.
    ///
    /// Resolves to the next response. Raises `ValueError` when the received
    /// item's `ok()` reports an error, and `StopAsyncIteration` once the
    /// channel is closed and drained.
    #[pyo3(name = "__anext__")]
    fn next<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let rx = self.rx.clone();
        let annotated = self.annotated;

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            loop {
                let received = rx.lock().await.recv().await;
                let Some(item) = received else {
                    return Err(PyStopAsyncIteration::new_err("Stream exhausted"));
                };

                let item = item
                    .ok()
                    .map_err(pyo3::exceptions::PyValueError::new_err)?;

                if annotated {
                    let wrapped = Annotated { inner: item };
                    // `into_py_any` surfaces object-creation failures as a Python
                    // error instead of panicking like the deprecated `into_py`.
                    return Python::with_gil(|py| wrapped.into_py_any(py));
                }

                // Plain mode yields only payloads; items without data are skipped.
                if let Some(data) = item.data {
                    return Ok(data);
                }
            }
        })
    }
}

/// Python-visible wrapper around a single annotated response item whose data
/// payload is a Python object.
#[pyclass]
struct Annotated {
    inner: RsAnnotated<PyObject>,
}

#[pymethods]
impl Annotated {
    /// Build an item from `data` via `RsAnnotated::from_data`.
    #[new]
    fn new(data: PyObject) -> Self {
        let inner = RsAnnotated::from_data(data);
        Self { inner }
    }

    /// Whether the wrapped item is an error.
    fn is_error(&self) -> bool {
        self.inner.is_error()
    }

    /// The data payload, if present.
    fn data(&self) -> Option<PyObject> {
        self.inner.data.clone()
    }

    /// The event name, if present.
    fn event(&self) -> Option<String> {
        self.inner.event.clone()
    }

    /// The attached comments, if present.
    fn comments(&self) -> Option<Vec<String>> {
        self.inner.comment.clone()
    }

    /// The item id, if present.
    fn id(&self) -> Option<String> {
        self.inner.id.clone()
    }

    /// Render as `Annotated(data=..., event=..., comment=[...], id=...)`.
    ///
    /// `data` is shown through its Python `__repr__`. If that call fails or does
    /// not return a `str`, `<failed_repr>` is shown; `<no_data>` is shown when
    /// there is no payload. Missing `event`/`id` render as `None`.
    #[pyo3(name = "__repr__")]
    fn _repr(&self, py: Python) -> String {
        let data_repr = match &self.inner.data {
            None => String::from("<no_data>"),
            Some(obj) => match obj
                .call_method0(py, "__repr__")
                .and_then(|repr_obj| repr_obj.extract::<Py<PyString>>(py))
            {
                Ok(py_str) => py_str.to_string_lossy(py).into_owned(),
                Err(_) => String::from("<failed_repr>"),
            },
        };
        let comments: &[String] = self.inner.comment.as_deref().unwrap_or(&[]);

        format!(
            "Annotated(data={}, event={}, comment={:?}, id={})",
            data_repr,
            self.inner.event.as_deref().unwrap_or("None"),
            comments,
            self.inner.id.as_deref().unwrap_or("None")
        )
    }
}