lib.rs 32.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use dynamo_llm::local_model::LocalModel;
5
use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
6
use dynamo_runtime::storage::kv;
7
8
use futures::StreamExt;
use once_cell::sync::OnceCell;
9
use pyo3::IntoPyObjectExt;
10
use pyo3::exceptions::PyStopAsyncIteration;
Richard Huo's avatar
Richard Huo committed
11
use pyo3::types::PyCapsule;
12
use pyo3::types::{PyDict, PyString};
13
14
use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress;
Richard Huo's avatar
Richard Huo committed
15
use std::ffi::CString;
16
use std::fs;
17
use std::path::PathBuf;
Richard Huo's avatar
Richard Huo committed
18
19
20
21
use std::{
    fmt::Display,
    sync::{Arc, Weak},
};
22
use tokio::sync::Mutex;
23
use tracing::Instrument;
24

25
use dynamo_runtime::config::environment_names::logging::otlp as env_otlp;
Neelay Shah's avatar
Neelay Shah committed
26
use dynamo_runtime::{
Ryan Olson's avatar
Ryan Olson committed
27
    self as rs, logging,
28
    pipeline::{
29
        AsyncEngineContextProvider, EngineStream, ManyOut, SingleIn, context::Context as RsContext,
30
        network::egress::push_router::RouterMode as RsRouterMode,
31
    },
32
    protocols::annotated::Annotated as RsAnnotated,
33
    traits::DistributedRuntimeProvider,
34
35
};

Neelay Shah's avatar
Neelay Shah committed
36
use dynamo_llm::{self as llm_rs};
37
38
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};

39
use crate::llm::local_model::ModelRuntimeConfig;
40
use crate::llm::preprocessor::{MediaDecoder, MediaFetcher};
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
    RoundRobin,
    Random,
    KV,
}

impl From<RouterMode> for RsRouterMode {
    fn from(mode: RouterMode) -> Self {
        match mode {
            RouterMode::RoundRobin => Self::RoundRobin,
            RouterMode::Random => Self::Random,
            RouterMode::KV => Self::KV,
        }
    }
}
59

60
mod context;
61
mod engine;
62
mod http;
63
mod kserve_grpc;
64
mod llm;
65
mod parsers;
66
mod planner;
67
mod prometheus_metrics;
68
69
70
71
72
73
74
75

/// Default value of the `annotated` keyword argument on the `Client` request methods.
const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);

/// Ensures the worker's tokio runtime is handed to `pyo3_async_runtimes` only once.
static INIT: OnceCell<()> = OnceCell::new();

/// Ingress that accepts a single JSON request and answers with a stream of annotated JSON.
type JsonServerStreamingIngress =
    Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>;

76
77
// Helper to get appropriate span for instrumentation - always emit spans
fn get_span_for_context(context: &context::Context, operation: &str) -> tracing::Span {
78
79
80
81
82
83
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        None,
    )
84
85
86
87
88
89
90
91
}

// Helper to create span for direct method with instance_id
fn get_span_for_direct_context(
    context: &context::Context,
    operation: &str,
    instance_id: &str,
) -> tracing::Span {
92
93
94
95
96
97
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        Some(instance_id),
    )
98
99
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/// Wrap a JSON `request` in a runtime context, linking it under `parent_ctx` when present.
///
/// With a parent, the new context reuses the parent's request id and is registered as the
/// parent's child. If the parent is already stopped or killed, the child is immediately
/// asked to stop generating. Without a parent, the request is converted with `into()`.
fn create_request_context(
    request: serde_json::Value,
    parent_ctx: &Option<context::Context>,
) -> RsContext<serde_json::Value> {
    let Some(parent) = parent_ctx else {
        return request.into();
    };

    let parent_inner = parent.inner();
    let child_ctx = RsContext::with_id(request, parent_inner.id().to_string());
    parent_inner.link_child(child_ctx.context());

    // Cancellation is forwarded as a soft stop rather than an error, because not all
    // backends properly handle request exceptions yet.
    // TODO: (DIS-830) Return an error if context is cancelled
    let parent_cancelled = parent_inner.is_stopped() || parent_inner.is_killed();
    if parent_cancelled {
        child_ctx.context().stop_generating();
    }

    child_ctx
}

123
124
125
126
127
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
128
    // Initialize logging early unless OTEL export is enabled (which requires tokio runtime)
129
130
131
132
    if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
        .map(|v| v == "1")
        .unwrap_or(false)
    {
133
        eprintln!(
134
            "Warning: OTEL_EXPORT_ENABLED detected. Logging initialization deferred until runtime is available. Early logs may be dropped."
135
136
137
138
139
        );
    } else {
        rs::logging::init();
    }

Yan Ru Pei's avatar
Yan Ru Pei committed
140
    m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
141
    m.add_function(wrap_pyfunction!(lora_name_to_id, m)?)?;
Ryan Olson's avatar
Ryan Olson committed
142
    m.add_function(wrap_pyfunction!(log_message, m)?)?;
143
    m.add_function(wrap_pyfunction!(register_llm, m)?)?;
144
    m.add_function(wrap_pyfunction!(unregister_llm, m)?)?;
145
    m.add_function(wrap_pyfunction!(fetch_llm, m)?)?;
146
147
    m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
    m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
148

149
150
151
152
153
154
155
    m.add_class::<DistributedRuntime>()?;
    m.add_class::<CancellationToken>()?;
    m.add_class::<Namespace>()?;
    m.add_class::<Component>()?;
    m.add_class::<Endpoint>()?;
    m.add_class::<Client>()?;
    m.add_class::<AsyncResponseStream>()?;
156
157
158
    m.add_class::<llm::entrypoint::EntrypointArgs>()?;
    m.add_class::<llm::entrypoint::EngineConfig>()?;
    m.add_class::<llm::entrypoint::EngineType>()?;
159
160
    m.add_class::<llm::entrypoint::RouterConfig>()?;
    m.add_class::<llm::entrypoint::KvRouterConfig>()?;
161
    m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
162
    m.add_class::<llm::model_card::ModelDeploymentCard>()?;
163
    m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
164
    m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
165
166
    m.add_class::<llm::preprocessor::MediaDecoder>()?;
    m.add_class::<llm::preprocessor::MediaFetcher>()?;
167
    m.add_class::<llm::backend::Backend>()?;
168
169
    m.add_class::<llm::kv::OverlapScores>()?;
    m.add_class::<llm::kv::KvIndexer>()?;
170
    m.add_class::<llm::kv::ApproxKvIndexer>()?;
171
    m.add_class::<llm::kv::KvEventPublisher>()?;
Yan Ru Pei's avatar
Yan Ru Pei committed
172
173
    m.add_class::<llm::kv::RadixTree>()?;
    m.add_class::<llm::kv::ZmqKvEventListener>()?;
174
175
    m.add_class::<llm::kv::ZmqKvEventPublisher>()?;
    m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
176
    m.add_class::<llm::kv::KvRecorder>()?;
177
    m.add_class::<llm::lora::LoRADownloader>()?;
178
179
    m.add_class::<http::HttpService>()?;
    m.add_class::<http::HttpAsyncEngine>()?;
180
    m.add_class::<context::Context>()?;
181
    m.add_class::<ModelType>()?;
182
    m.add_class::<ModelInput>()?;
183
184
185
186
    m.add_class::<llm::kv::ForwardPassMetrics>()?;
    m.add_class::<llm::kv::WorkerStats>()?;
    m.add_class::<llm::kv::KvStats>()?;
    m.add_class::<llm::kv::SpecDecodeStats>()?;
187
188
    m.add_class::<llm::kv::KvPushRouter>()?;
    m.add_class::<llm::kv::KvPushRouterStream>()?;
189
    m.add_class::<RouterMode>()?;
190
    m.add_class::<kserve_grpc::KserveGrpcService>()?;
191
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
192
193
194
    m.add_class::<planner::VirtualConnectorCoordinator>()?;
    m.add_class::<planner::VirtualConnectorClient>()?;
    m.add_class::<planner::PlannerDecision>()?;
195
196

    engine::add_to_module(m)?;
197
    parsers::add_to_module(m)?;
198

199
    m.add_class::<prometheus_metrics::RuntimeMetrics>()?;
200
201
202
203
    let prometheus_metrics = PyModule::new(m.py(), "prometheus_metrics")?;
    prometheus_metrics::add_to_module(&prometheus_metrics)?;
    m.add_submodule(&prometheus_metrics)?;

204
205
206
207
208
209
210
211
212
213
    Ok(())
}

/// Convert any displayable error into a plain Python `Exception` carrying its message.
///
/// Intended for `.map_err(to_pyerr)` at the Rust/Python boundary.
pub fn to_pyerr<E: Display>(err: E) -> PyErr {
    PyException::new_err(err.to_string())
}

Ryan Olson's avatar
Ryan Olson committed
214
215
216
217
218
219
220
/// Log a message from Python with file and line info
#[pyfunction]
#[pyo3(text_signature = "(level, message, module, file, line)")]
fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) {
    logging::log_message(level, message, module, file, line);
}

221
222
223
224
225
226
227
/// Generate a deterministic signed int32 ID from a LoRA name using blake3 hash.
#[pyfunction]
#[pyo3(text_signature = "(lora_name)")]
fn lora_name_to_id(lora_name: &str) -> i32 {
    llm_rs::utils::lora_name_to_id(lora_name)
}

228
229
/// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend.
230
#[pyfunction]
231
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None))]
232
#[allow(clippy::too_many_arguments)]
233
234
fn register_llm<'p>(
    py: Python<'p>,
235
    model_input: ModelInput,
236
    model_type: ModelType,
237
238
239
    endpoint: Endpoint,
    model_path: &str,
    model_name: Option<&str>,
240
241
    context_length: Option<u32>,
    kv_cache_block_size: Option<u32>,
242
    router_mode: Option<RouterMode>,
243
    migration_limit: u32,
244
    runtime_config: Option<ModelRuntimeConfig>,
245
    user_data: Option<&Bound<'p, PyDict>>,
246
    custom_template_path: Option<&str>,
247
248
    media_decoder: Option<MediaDecoder>,
    media_fetcher: Option<MediaFetcher>,
249
) -> PyResult<Bound<'p, PyAny>> {
250
251
252
253
254
255
256
257
258
259
260
261
262
263
    // Validate Prefill model type requirements
    if model_type.inner == llm_rs::model_type::ModelType::Prefill {
        if !matches!(model_input, ModelInput::Tokens) {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires model_input to be ModelInput::Tokens",
            ));
        }
        if migration_limit != 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires migration_limit to be 0",
            ));
        }
    }

264
265
266
    let model_input = match model_input {
        ModelInput::Text => llm_rs::model_type::ModelInput::Text,
        ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
267
        ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
268
269
    };

270
271
    let model_type_obj = model_type.inner;

272
    let inner_path = model_path.to_string();
273
    let mut model_name = model_name.map(|n| n.to_string());
274
275
276
    let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
    let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());

277
278
279
280
281
282
283
284
285
286
287
288
289
    // Early validation of custom template path
    let custom_template_path_owned = custom_template_path
        .map(|s| {
            let path = PathBuf::from(s);
            if !path.exists() {
                return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
                    format!("Custom template file does not exist: {}", path.display()),
                ));
            }
            Ok(path)
        })
        .transpose()?;

290
291
292
293
294
295
296
    let user_data_json = user_data
        .map(|dict| pythonize::depythonize(dict))
        .transpose()
        .map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
        })?;

297
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
298
299
300
301
302
303
304
305
306
307
308
309
310
        let model_path = if fs::exists(&inner_path)? {
            PathBuf::from(inner_path)
        } else {
            // Preserve the model name
            if model_name.is_none() {
                model_name = Some(inner_path.clone());
            }
            // Likely it's a Hugging Face repo, download it
            LocalModel::fetch(&inner_path, false)
                .await
                .map_err(to_pyerr)?
        };

311
312
        let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
        builder
313
            .model_path(model_path)
314
315
            .model_name(model_name)
            .context_length(context_length)
316
            .kv_cache_block_size(kv_cache_block_size)
317
            .router_config(Some(router_config))
318
            .migration_limit(Some(migration_limit))
319
            .runtime_config(runtime_config.unwrap_or_default().inner)
320
            .user_data(user_data_json)
321
322
323
            .custom_template_path(custom_template_path_owned)
            .media_decoder(media_decoder.map(|m| m.inner))
            .media_fetcher(media_fetcher.map(|m| m.inner));
324
        // Load the ModelDeploymentCard
325
        let mut local_model = builder.build().await.map_err(to_pyerr)?;
326
        // Advertise ourself so ingress can find us
327
        local_model
328
            .attach(&endpoint.inner, model_type_obj, model_input)
329
330
331
332
333
334
335
            .await
            .map_err(to_pyerr)?;

        Ok(())
    })
}

336
337
338
339
340
341
342
343
344
345
346
347
/// Unregister the model attached to `endpoint`.
///
/// Returns an awaitable; errors from the detach raise `Exception`.
#[pyfunction]
#[pyo3(signature = (endpoint))]
fn unregister_llm<'p>(py: Python<'p>, endpoint: Endpoint) -> PyResult<Bound<'p, PyAny>> {
    let detach = async move {
        let target = &endpoint.inner;
        LocalModel::detach_model_from_endpoint(target)
            .await
            .map_err(to_pyerr)?;
        Ok(())
    };
    pyo3_async_runtimes::tokio::future_into_py(py, detach)
}

348
349
350
351
352
353
354
355
356
357
358
/// Download a model from Hugging Face, returning its local path.
/// Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
#[pyfunction]
#[pyo3(signature = (remote_name))]
fn fetch_llm<'p>(py: Python<'p>, remote_name: &str) -> PyResult<Bound<'p, PyAny>> {
    // Own the name so it can move into the 'static future.
    let repo_id = String::from(remote_name);
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let fetched = LocalModel::fetch(&repo_id, false).await;
        fetched.map_err(to_pyerr)
    })
}

359
360
#[pyclass]
#[derive(Clone)]
Ryan Olson's avatar
Ryan Olson committed
361
pub struct DistributedRuntime {
362
363
364
365
    inner: rs::DistributedRuntime,
    event_loop: PyObject,
}

Ryan Olson's avatar
Ryan Olson committed
366
367
impl DistributedRuntime {
    #[allow(dead_code)]
368
    pub(crate) fn inner(&self) -> &rs::DistributedRuntime {
Ryan Olson's avatar
Ryan Olson committed
369
370
371
372
        &self.inner
    }
}

373
#[pyclass]
374
#[derive(Clone)]
375
376
377
378
379
struct CancellationToken {
    inner: rs::CancellationToken,
}

#[pyclass]
380
#[derive(Clone)]
381
382
383
384
385
386
struct Namespace {
    inner: rs::component::Namespace,
    event_loop: PyObject,
}

#[pyclass]
387
#[derive(Clone)]
388
389
390
391
392
393
struct Component {
    inner: rs::component::Component,
    event_loop: PyObject,
}

#[pyclass]
394
#[derive(Clone)]
395
396
397
398
399
400
struct Endpoint {
    inner: rs::component::Endpoint,
    event_loop: PyObject,
}

#[pyclass]
401
#[derive(Clone)]
402
struct Client {
403
    router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
404
405
}

406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
/// Python wrapper around `dynamo_llm`'s `ModelType`.
///
/// Values can be combined with the `|` operator.
#[pyclass]
#[derive(Clone, PartialEq)]
struct ModelType {
    /// Wrapped `dynamo_llm` model type value.
    inner: llm_rs::model_type::ModelType,
}

#[pymethods]
// Class attributes deliberately mirror the Rust variant names (`ModelType.Chat`, ...).
#[allow(non_upper_case_globals)]
impl ModelType {
    #[classattr]
    const Chat: Self = Self {
        inner: llm_rs::model_type::ModelType::Chat,
    };

    #[classattr]
    const Completions: Self = Self {
        inner: llm_rs::model_type::ModelType::Completions,
    };

    #[classattr]
    const Embedding: Self = Self {
        inner: llm_rs::model_type::ModelType::Embedding,
    };

    #[classattr]
    const TensorBased: Self = Self {
        inner: llm_rs::model_type::ModelType::TensorBased,
    };

    #[classattr]
    const Prefill: Self = Self {
        inner: llm_rs::model_type::ModelType::Prefill,
    };

    /// Combine two model types with `|`.
    fn __or__(&self, other: &Self) -> Self {
        let combined = self.inner | other.inner;
        Self { inner: combined }
    }

    /// Display form of the wrapped model type.
    fn __str__(&self) -> String {
        let rendered = self.inner.to_string();
        rendered
    }
}

447
448
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)]
449
450
451
enum ModelInput {
    Text = 1,
    Tokens = 2,
452
    Tensor = 3,
453
454
}

455
456
457
#[pymethods]
impl DistributedRuntime {
    #[new]
458
    fn new(event_loop: PyObject, store_kv: String, request_plane: String) -> PyResult<Self> {
459
        let selected_kv_store: kv::Selector = store_kv.parse().map_err(to_pyerr)?;
460
        let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?;
461

462
463
464
        // Try to get existing runtime first, create new Worker only if needed
        // This allows multiple DistributedRuntime instances to share the same tokio runtime
        let runtime = rs::Worker::runtime_from_existing()
465
            .or_else(|_| -> anyhow::Result<rs::Runtime> {
466
467
468
469
                // No existing Worker, create new one
                let worker = rs::Worker::from_settings()?;

                // Initialize pyo3 bridge (only happens once per process)
470
                INIT.get_or_try_init(|| -> anyhow::Result<()> {
471
472
                    let primary = worker.tokio_runtime()?;
                    pyo3_async_runtimes::tokio::init_with_runtime(primary).map_err(|e| {
473
                        anyhow::anyhow!("failed to initialize pyo3 static runtime: {:?}", e)
474
                    })?;
475
                    Ok(())
476
477
                })?;

478
                Ok(worker.runtime().clone())
479
480
            })
            .map_err(to_pyerr)?;
481

482
483
        // Initialize logging in context where tokio runtime is available
        // otel exporter requires it
484
485
486
487
        if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
            .map(|v| v == "1")
            .unwrap_or(false)
        {
488
489
490
491
            runtime.secondary().block_on(async {
                rs::logging::init();
            });
        }
492

493
494
        let runtime_config = DistributedConfig {
            store_backend: selected_kv_store,
495
496
497
498
499
500
            // We only need NATS here to monitor it's metrics, so only if it's our request plane.
            nats_config: if request_plane.is_nats() {
                Some(dynamo_runtime::transports::nats::ClientOptions::default())
            } else {
                None
            },
501
            request_plane,
502
503
504
505
506
        };
        let inner = runtime
            .secondary()
            .block_on(rs::DistributedRuntime::new(runtime, runtime_config))
            .map_err(to_pyerr)?;
507
508
509
510

        Ok(DistributedRuntime { inner, event_loop })
    }

Ryan Olson's avatar
Ryan Olson committed
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
    #[staticmethod]
    fn detached(py: Python) -> PyResult<Self> {
        let rt = rs::Worker::runtime_from_existing().map_err(to_pyerr)?;
        let handle = rt.primary();

        let inner = handle
            .block_on(rs::DistributedRuntime::from_settings(rt))
            .map_err(to_pyerr)?;

        Ok(DistributedRuntime {
            inner,
            event_loop: py.None(),
        })
    }

526
527
528
529
530
531
532
533
    fn namespace(&self, name: String) -> PyResult<Namespace> {
        Ok(Namespace {
            inner: self.inner.namespace(name).map_err(to_pyerr)?,
            event_loop: self.event_loop.clone(),
        })
    }

    fn shutdown(&self) {
534
        self.inner.shutdown();
535
536
537
538
539
    }

    fn event_loop(&self) -> PyObject {
        self.event_loop.clone()
    }
540
541
542
543
544

    fn child_token(&self) -> CancellationToken {
        let inner = self.inner.runtime().child_token();
        CancellationToken { inner }
    }
Richard Huo's avatar
Richard Huo committed
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559

    // This is used to pass the DistributedRuntime from the dynamo-runtime bindings
    // to the KVBM bindings, since KVBM cannot directly use the struct from this cdylib.
    // TODO: Create a separate crate "dynamo-python" so that all binding crates can import
    // from it and share the same crate path. This will allow PyO3 to automatically
    // recognize that both bindings use the same PyClass.
    #[pyo3(name = "to_capsule")]
    fn to_capsule<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
        let arc: Arc<rs::DistributedRuntime> = Arc::new(self.inner.clone());
        let weak: Weak<rs::DistributedRuntime> = Arc::downgrade(&arc);

        let name = CString::new("dynamo.runtime.weak").expect("valid capsule name");

        PyCapsule::new(py, weak, Some(name))
    }
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
}

#[pymethods]
impl CancellationToken {
    fn cancel(&self) {
        self.inner.cancel();
    }

    fn cancelled<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let token = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            token.cancelled().await;
            Ok(())
        })
    }
}

#[pymethods]
impl Component {
    fn endpoint(&self, name: String) -> PyResult<Endpoint> {
        let inner = self.inner.endpoint(name);
        Ok(Endpoint {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }

587
588
    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
589
590
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_component(self.inner.clone())
591
    }
592
593
594
595
}

#[pymethods]
impl Endpoint {
596
    #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None))]
597
598
599
600
    fn serve_endpoint<'p>(
        &self,
        py: Python<'p>,
        generator: PyObject,
601
        graceful_shutdown: Option<bool>,
602
        metrics_labels: Option<Vec<(String, String)>>,
603
        health_check_payload: Option<&Bound<'p, PyDict>>,
604
605
606
607
608
609
    ) -> PyResult<Bound<'p, PyAny>> {
        let engine = Arc::new(engine::PythonAsyncEngine::new(
            generator,
            self.event_loop.clone(),
        )?);
        let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
610
611
612
613
614
615
616
617
618
619
620
621
622

        // Convert Python dict to serde_json::Value if provided and validate it's an object
        let health_payload_json = health_check_payload
            .map(|dict| pythonize::depythonize::<serde_json::Value>(dict))
            .transpose()
            .map_err(|err| {
                pyo3::exceptions::PyTypeError::new_err(format!(
                    "Failed to convert health_check_payload: {}",
                    err
                ))
            })?;

        // Require an object/dict
623
624
625
626
627
628
        if let Some(ref payload) = health_payload_json
            && !payload.is_object()
        {
            return Err(pyo3::exceptions::PyTypeError::new_err(
                "health_check_payload must be a JSON object (dict)",
            ));
629
630
631
        }

        let mut builder = self
632
633
634
635
            .inner
            .endpoint_builder()
            .metrics_labels(metrics_labels)
            .handler(ingress);
636
637
638
639
640

        if let Some(payload) = health_payload_json {
            builder = builder.health_check_payload(payload);
        }

641
        let graceful_shutdown = graceful_shutdown.unwrap_or(true);
642
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
643
644
645
646
647
            builder
                .graceful_shutdown(graceful_shutdown)
                .start()
                .await
                .map_err(to_pyerr)?;
648
649
650
651
            Ok(())
        })
    }

652
653
654
    fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
655
            let client = inner.client().await.map_err(to_pyerr)?;
656
657
658
659
660
661
            let push_router = rs::pipeline::PushRouter::<
                serde_json::Value,
                RsAnnotated<serde_json::Value>,
            >::from_client(client, Default::default())
            .await
            .map_err(to_pyerr)?;
662
663
664
            Ok(Client {
                router: push_router,
            })
665
666
        })
    }
667

668
669
670
    // Opaque unique ID for this worker. May change over worker lifetime.
    fn connection_id(&self) -> u64 {
        self.inner.drt().connection_id()
671
    }
672
673
674

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
675
676
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_endpoint(self.inner.clone())
677
    }
678
679
680
681
682
683
684
685
686
687
688
}

#[pymethods]
impl Namespace {
    fn component(&self, name: String) -> PyResult<Component> {
        let inner = self.inner.component(name).map_err(to_pyerr)?;
        Ok(Component {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }
689
690
691

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
692
693
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_namespace(self.inner.clone())
694
    }
695
696
697
698
}

#[pymethods]
impl Client {
699
700
    /// Get list of current instances.
    /// Replaces endpoint_ids.
701
    fn instance_ids(&self) -> Vec<u64> {
702
        self.router.client.instance_ids()
703
704
    }

705
706
707
    /// Wait for an instance to be available for work.
    /// Replaces wait_for_endpoints.
    fn wait_for_instances<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
708
        let inner = self.router.client.clone();
709
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
710
            inner
711
                .wait_for_instances()
712
                .await
713
                .map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<u64>>())
714
                .map_err(to_pyerr)
715
716
717
718
        })
    }

    /// Issue a request to the endpoint using the default routing strategy.
719
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
720
721
722
723
724
    fn generate<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
725
        context: Option<context::Context>,
726
    ) -> PyResult<Bound<'p, PyAny>> {
727
        self.random(py, request, annotated, context)
728
729
730
    }

    /// Send a request to the next endpoint in a round-robin fashion.
731
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
732
733
734
735
736
    fn round_robin<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
737
        context: Option<context::Context>,
738
739
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
740
        let request_ctx = create_request_context(request, &context);
741
742
743
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
744
        let client = self.router.clone();
745
746

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
747
748
            let stream = match context {
                Some(context) => {
749
750
                    // Always instrument with appropriate span (none if no trace context)
                    let span = get_span_for_context(&context, "round_robin");
751
752
753
754
755
                    client
                        .round_robin(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
756
                }
757
                _ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
758
            };
759
760
761
762
763
764
765
766
767
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Send a request to a random endpoint.
768
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
769
770
771
772
773
    fn random<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
774
        context: Option<context::Context>,
775
776
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
777
        let request_ctx = create_request_context(request, &context);
778
779
780
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
781
        let client = self.router.clone();
782
783

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
784
785
            let stream = match context {
                Some(context) => {
786
                    // Always instrument with appropriate span (none if no trace context)
787
                    let span = get_span_for_context(&context, "random");
788
789
790
791
792
                    client
                        .random(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
793
                }
794
                _ => client.random(request_ctx).await.map_err(to_pyerr)?,
795
            };
796
797
798
799
800
801
802
803
804
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Directly send a request to a specific endpoint.
805
    #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
806
807
808
809
    fn direct<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
810
        instance_id: u64,
811
        annotated: Option<bool>,
812
        context: Option<context::Context>,
813
814
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
815
        let request_ctx = create_request_context(request, &context);
816
817
818
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
819
        let client = self.router.clone();
820
821

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
822
823
            let stream = match context {
                Some(context) => {
824
                    // Always instrument with appropriate span (none if no trace context)
825
826
                    let span =
                        get_span_for_direct_context(&context, "direct", &instance_id.to_string());
827
828
829
830
831
                    client
                        .direct(request_ctx, instance_id)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
832
833
                }
                _ => client
834
                    .direct(request_ctx, instance_id)
835
836
837
                    .await
                    .map_err(to_pyerr)?,
            };
838
839
840
841
842
843
844
845
846
847
848
849

            tokio::spawn(process_stream(stream, tx));

            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }
}

async fn process_stream(
850
    stream: EngineStream<RsAnnotated<serde_json::Value>>,
851
852
853
854
855
    tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
) {
    let mut stream = stream;
    while let Some(response) = stream.next().await {
        // Convert the response to a PyObject using Python's GIL
856
        let annotated: RsAnnotated<serde_json::Value> = response;
857
        let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
858
            Python::with_gil(|py| match pythonize::pythonize(py, &data) {
859
860
                Ok(pyobj) => Ok(pyobj.into()),
                Err(e) => Err(e.to_string()),
861
            })
862
863
864
865
866
867
        });

        let is_error = annotated.is_error();

        // Send the PyObject through the channel or log an error
        if let Err(e) = tx.send(annotated).await {
868
            tracing::error!("Failed to send response: {:?}", e);
869
            break;
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
        }

        if is_error {
            break;
        }
    }
}

/// Python async iterator over the responses of a single request.
///
/// Items are received over `rx` from a background forwarding task. The receiver
/// is wrapped in `Arc<tokio::sync::Mutex<..>>` for two reasons:
/// * each `__anext__` future can own a handle to it, since the future cannot
///   borrow `self`;
/// * the receiver can be awaited while locked.
#[pyclass]
struct AsyncResponseStream {
    /// Receiving end of the channel fed by the background forwarding task.
    rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>,
    /// When true, yield `Annotated` wrappers; otherwise yield only `data` payloads.
    annotated: bool,
}

#[pymethods]
impl AsyncResponseStream {
    /// Returns `self`, as required by the `AsyncIterator` protocol.
    #[pyo3(name = "__aiter__")]
    fn aiter(slf: PyRef<Self>, py: Python) -> PyResult<Py<PyAny>> {
        slf.into_py_any(py)
    }

    /// Returns an awaitable that resolves to the next item, as required by the
    /// `AsyncIterator` protocol.
    ///
    /// * In annotated mode, every received item is wrapped in `Annotated`.
    /// * Otherwise only the item's `data` payload is returned, and items without
    ///   a payload are skipped.
    ///
    /// # Errors
    /// * `ValueError` when a received item reports an error through `ok()`.
    /// * `StopAsyncIteration` once the channel is closed and drained.
    /// * Any `PyErr` raised while creating the `Annotated` Python object.
    #[pyo3(name = "__anext__")]
    fn next<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let rx = self.rx.clone();
        let annotated = self.annotated;

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            loop {
                // The lock is held only for this statement; it serialises
                // concurrent `__anext__` calls on the same stream.
                let received = rx.lock().await.recv().await;
                let Some(item) = received else {
                    return Err(PyStopAsyncIteration::new_err("Stream exhausted"));
                };

                let item = match item.ok() {
                    Ok(item) => item,
                    Err(message) => {
                        return Err(pyo3::exceptions::PyValueError::new_err(message));
                    }
                };

                if annotated {
                    let object = Annotated { inner: item };
                    // `into_py_any` propagates object-creation failures as a PyErr.
                    // The deprecated `into_py` path unwraps them instead.
                    return Python::with_gil(|py| object.into_py_any(py));
                }

                if let Some(data) = item.data {
                    return Ok(data);
                }
                // Payload-less item in non-annotated mode: wait for the next one.
            }
        })
    }
}

/// Python-facing wrapper around one response item and its metadata.
///
/// Exposes the optional payload (`data`) together with the optional `event`,
/// `comment` and `id` fields of the wrapped `RsAnnotated` value.
#[pyclass]
struct Annotated {
    inner: RsAnnotated<PyObject>,
}

#[pymethods]
impl Annotated {
    /// Creates an item from `data` using `RsAnnotated::from_data`.
    #[new]
    fn new(data: PyObject) -> Self {
        let inner = RsAnnotated::from_data(data);
        Self { inner }
    }

    /// Whether the wrapped item is flagged as an error (delegates to
    /// `RsAnnotated::is_error`).
    fn is_error(&self) -> bool {
        self.inner.is_error()
    }

    /// The payload object, or `None` when the item carries none.
    fn data(&self) -> Option<PyObject> {
        self.inner.data.as_ref().cloned()
    }

    /// The `event` field, if set.
    fn event(&self) -> Option<String> {
        self.inner.event.as_ref().cloned()
    }

    /// The `comment` lines, if any. The Python-visible name is `comments`.
    fn comments(&self) -> Option<Vec<String>> {
        self.inner.comment.as_ref().cloned()
    }

    /// The `id` field, if set.
    fn id(&self) -> Option<String> {
        self.inner.id.as_ref().cloned()
    }

    /// Debug-style representation.
    ///
    /// The payload is rendered with its own `__repr__`. If that call, or its
    /// conversion to `str`, fails, `<failed_repr>` is shown instead. A missing
    /// payload is shown as `<no_data>`.
    #[pyo3(name = "__repr__")]
    fn _repr(&self, py: Python) -> String {
        let data_text = match &self.inner.data {
            Some(obj) => obj
                .call_method0(py, "__repr__")
                .and_then(|repr_obj| repr_obj.extract::<Py<PyString>>(py))
                .map(|py_str| py_str.to_string_lossy(py).into_owned())
                .unwrap_or_else(|_| "<failed_repr>".to_string()),
            None => "<no_data>".to_string(),
        };
        let no_comments: &[String] = &[];

        format!(
            "Annotated(data={}, event={}, comment={:?}, id={})",
            data_text,
            self.inner.event.as_deref().unwrap_or("None"),
            self.inner.comment.as_deref().unwrap_or(no_comments),
            self.inner.id.as_deref().unwrap_or("None")
        )
    }
}