lib.rs 32.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use dynamo_llm::local_model::LocalModel;
5
use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
6
use dynamo_runtime::storage::key_value_store::KeyValueStoreSelect;
7
8
use futures::StreamExt;
use once_cell::sync::OnceCell;
9
use pyo3::IntoPyObjectExt;
10
use pyo3::exceptions::PyStopAsyncIteration;
Richard Huo's avatar
Richard Huo committed
11
use pyo3::types::PyCapsule;
12
use pyo3::types::{PyDict, PyString};
13
14
use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress;
Richard Huo's avatar
Richard Huo committed
15
use std::ffi::CString;
16
use std::fs;
17
use std::path::PathBuf;
Richard Huo's avatar
Richard Huo committed
18
19
20
21
use std::{
    fmt::Display,
    sync::{Arc, Weak},
};
22
use tokio::sync::Mutex;
23
use tracing::Instrument;
24

25
use dynamo_runtime::config::environment_names::logging::otlp as env_otlp;
Neelay Shah's avatar
Neelay Shah committed
26
use dynamo_runtime::{
Ryan Olson's avatar
Ryan Olson committed
27
    self as rs, logging,
28
    pipeline::{
29
        AsyncEngineContextProvider, EngineStream, ManyOut, SingleIn, context::Context as RsContext,
30
        network::egress::push_router::RouterMode as RsRouterMode,
31
    },
32
    protocols::annotated::Annotated as RsAnnotated,
33
    traits::DistributedRuntimeProvider,
34
35
};

Neelay Shah's avatar
Neelay Shah committed
36
use dynamo_llm::{self as llm_rs};
37
38
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};

39
use crate::llm::local_model::ModelRuntimeConfig;
40
use crate::llm::preprocessor::{MediaDecoder, MediaFetcher};
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
    RoundRobin,
    Random,
    KV,
}

impl From<RouterMode> for RsRouterMode {
    fn from(mode: RouterMode) -> Self {
        match mode {
            RouterMode::RoundRobin => Self::RoundRobin,
            RouterMode::Random => Self::Random,
            RouterMode::KV => Self::KV,
        }
    }
}
59

60
mod context;
61
mod engine;
62
mod http;
63
mod kserve_grpc;
64
mod llm;
65
mod parsers;
66
mod planner;
67
mod prometheus_metrics;
68
69
70
71
72
73
74
75

/// Ingress used by `Endpoint::serve_endpoint`: one JSON request in, a stream of annotated
/// JSON responses out.
type JsonServerStreamingIngress =
    Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>;

/// Guards the one-time pyo3 <-> tokio bridge initialization in `DistributedRuntime::new`.
static INIT: OnceCell<()> = OnceCell::new();

/// Default for the `annotated` keyword of the `Client` request methods when it is omitted.
/// Note: an explicit `annotated=None` is turned into `false` by those methods
/// (`annotated.unwrap_or(false)`), not into this default.
const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);

76
77
// Helper to get appropriate span for instrumentation - always emit spans
fn get_span_for_context(context: &context::Context, operation: &str) -> tracing::Span {
78
79
80
81
82
83
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        None,
    )
84
85
86
87
88
89
90
91
}

// Helper to create span for direct method with instance_id
fn get_span_for_direct_context(
    context: &context::Context,
    operation: &str,
    instance_id: &str,
) -> tracing::Span {
92
93
94
95
96
97
    logging::make_client_request_span(
        operation,
        context.inner().id(),
        context.trace_context(),
        Some(instance_id),
    )
98
99
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/// Wraps `request` in a runtime context, optionally linked to a caller-supplied parent.
///
/// * No parent: the request is converted into a context as-is.
/// * With a parent: the new context reuses the parent's id and is registered as a child
///   via `link_child`. If the parent is already stopped or killed, the child is told to
///   stop generating immediately.
fn create_request_context(
    request: serde_json::Value,
    parent_ctx: &Option<context::Context>,
) -> RsContext<serde_json::Value> {
    let Some(parent) = parent_ctx else {
        return request.into();
    };

    let parent_inner = parent.inner();
    let child = RsContext::with_id(request, parent_inner.id().to_string());
    parent_inner.link_child(child.context());

    let parent_cancelled = parent_inner.is_stopped() || parent_inner.is_killed();
    if parent_cancelled {
        // Cancellation is left to the server for now, because not every backend handles
        // request exceptions properly.
        // TODO: (DIS-830) Return an error if context is cancelled
        child.context().stop_generating();
    }
    child
}

123
124
125
126
127
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
128
    // Initialize logging early unless OTEL export is enabled (which requires tokio runtime)
129
130
131
132
    if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
        .map(|v| v == "1")
        .unwrap_or(false)
    {
133
        eprintln!(
134
            "Warning: OTEL_EXPORT_ENABLED detected. Logging initialization deferred until runtime is available. Early logs may be dropped."
135
136
137
138
139
        );
    } else {
        rs::logging::init();
    }

Yan Ru Pei's avatar
Yan Ru Pei committed
140
    m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
141
    m.add_function(wrap_pyfunction!(lora_name_to_id, m)?)?;
Ryan Olson's avatar
Ryan Olson committed
142
    m.add_function(wrap_pyfunction!(log_message, m)?)?;
143
    m.add_function(wrap_pyfunction!(register_llm, m)?)?;
144
    m.add_function(wrap_pyfunction!(fetch_llm, m)?)?;
145
146
    m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
    m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
147

148
149
150
151
152
153
154
    m.add_class::<DistributedRuntime>()?;
    m.add_class::<CancellationToken>()?;
    m.add_class::<Namespace>()?;
    m.add_class::<Component>()?;
    m.add_class::<Endpoint>()?;
    m.add_class::<Client>()?;
    m.add_class::<AsyncResponseStream>()?;
155
156
157
    m.add_class::<llm::entrypoint::EntrypointArgs>()?;
    m.add_class::<llm::entrypoint::EngineConfig>()?;
    m.add_class::<llm::entrypoint::EngineType>()?;
158
159
    m.add_class::<llm::entrypoint::RouterConfig>()?;
    m.add_class::<llm::entrypoint::KvRouterConfig>()?;
160
    m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
161
    m.add_class::<llm::model_card::ModelDeploymentCard>()?;
162
    m.add_class::<llm::local_model::ModelRuntimeConfig>()?;
163
    m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
164
165
    m.add_class::<llm::preprocessor::MediaDecoder>()?;
    m.add_class::<llm::preprocessor::MediaFetcher>()?;
166
    m.add_class::<llm::backend::Backend>()?;
167
168
    m.add_class::<llm::kv::OverlapScores>()?;
    m.add_class::<llm::kv::KvIndexer>()?;
169
    m.add_class::<llm::kv::ApproxKvIndexer>()?;
170
    m.add_class::<llm::kv::KvEventPublisher>()?;
Yan Ru Pei's avatar
Yan Ru Pei committed
171
172
    m.add_class::<llm::kv::RadixTree>()?;
    m.add_class::<llm::kv::ZmqKvEventListener>()?;
173
174
    m.add_class::<llm::kv::ZmqKvEventPublisher>()?;
    m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
175
    m.add_class::<llm::kv::KvRecorder>()?;
176
177
    m.add_class::<http::HttpService>()?;
    m.add_class::<http::HttpAsyncEngine>()?;
178
    m.add_class::<context::Context>()?;
179
    m.add_class::<ModelType>()?;
180
    m.add_class::<ModelInput>()?;
181
182
183
184
    m.add_class::<llm::kv::ForwardPassMetrics>()?;
    m.add_class::<llm::kv::WorkerStats>()?;
    m.add_class::<llm::kv::KvStats>()?;
    m.add_class::<llm::kv::SpecDecodeStats>()?;
185
186
    m.add_class::<llm::kv::KvPushRouter>()?;
    m.add_class::<llm::kv::KvPushRouterStream>()?;
187
    m.add_class::<RouterMode>()?;
188
    m.add_class::<kserve_grpc::KserveGrpcService>()?;
189
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
190
191
192
    m.add_class::<planner::VirtualConnectorCoordinator>()?;
    m.add_class::<planner::VirtualConnectorClient>()?;
    m.add_class::<planner::PlannerDecision>()?;
193
194

    engine::add_to_module(m)?;
195
    parsers::add_to_module(m)?;
196

197
    m.add_class::<prometheus_metrics::RuntimeMetrics>()?;
198
199
200
201
    let prometheus_metrics = PyModule::new(m.py(), "prometheus_metrics")?;
    prometheus_metrics::add_to_module(&prometheus_metrics)?;
    m.add_submodule(&prometheus_metrics)?;

202
203
204
205
206
207
208
209
210
211
    Ok(())
}

/// Converts any displayable error into a generic Python `Exception`.
///
/// The exception message is the error's `Display` output.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    let message = err.to_string();
    PyException::new_err(message)
}

Ryan Olson's avatar
Ryan Olson committed
212
213
214
215
216
217
218
/// Log a message from Python with file and line info
///
/// Forwards all five arguments unchanged to `dynamo_runtime::logging::log_message`.
#[pyfunction]
#[pyo3(text_signature = "(level, message, module, file, line)")]
fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) {
    logging::log_message(level, message, module, file, line);
}

219
220
221
222
223
224
225
/// Generate a deterministic signed int32 ID from a LoRA name using blake3 hash.
///
/// Thin Python wrapper; delegates to `dynamo_llm::utils::lora_name_to_id`.
#[pyfunction]
#[pyo3(text_signature = "(lora_name)")]
fn lora_name_to_id(lora_name: &str) -> i32 {
    llm_rs::utils::lora_name_to_id(lora_name)
}

226
227
/// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend.
228
#[pyfunction]
229
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None))]
230
#[allow(clippy::too_many_arguments)]
231
232
fn register_llm<'p>(
    py: Python<'p>,
233
    model_input: ModelInput,
234
    model_type: ModelType,
235
236
237
    endpoint: Endpoint,
    model_path: &str,
    model_name: Option<&str>,
238
239
    context_length: Option<u32>,
    kv_cache_block_size: Option<u32>,
240
    router_mode: Option<RouterMode>,
241
    migration_limit: u32,
242
    runtime_config: Option<ModelRuntimeConfig>,
243
    user_data: Option<&Bound<'p, PyDict>>,
244
    custom_template_path: Option<&str>,
245
246
    media_decoder: Option<MediaDecoder>,
    media_fetcher: Option<MediaFetcher>,
247
) -> PyResult<Bound<'p, PyAny>> {
248
249
250
251
252
253
254
255
256
257
258
259
260
261
    // Validate Prefill model type requirements
    if model_type.inner == llm_rs::model_type::ModelType::Prefill {
        if !matches!(model_input, ModelInput::Tokens) {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires model_input to be ModelInput::Tokens",
            ));
        }
        if migration_limit != 0 {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "ModelType::Prefill requires migration_limit to be 0",
            ));
        }
    }

262
263
264
    let model_input = match model_input {
        ModelInput::Text => llm_rs::model_type::ModelInput::Text,
        ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
265
        ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
266
267
    };

268
269
    let model_type_obj = model_type.inner;

270
    let inner_path = model_path.to_string();
271
    let mut model_name = model_name.map(|n| n.to_string());
272
273
274
    let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
    let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());

275
276
277
278
279
280
281
282
283
284
285
286
287
    // Early validation of custom template path
    let custom_template_path_owned = custom_template_path
        .map(|s| {
            let path = PathBuf::from(s);
            if !path.exists() {
                return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
                    format!("Custom template file does not exist: {}", path.display()),
                ));
            }
            Ok(path)
        })
        .transpose()?;

288
289
290
291
292
293
294
    let user_data_json = user_data
        .map(|dict| pythonize::depythonize(dict))
        .transpose()
        .map_err(|err| {
            PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
        })?;

295
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
296
297
298
299
300
301
302
303
304
305
306
307
308
        let model_path = if fs::exists(&inner_path)? {
            PathBuf::from(inner_path)
        } else {
            // Preserve the model name
            if model_name.is_none() {
                model_name = Some(inner_path.clone());
            }
            // Likely it's a Hugging Face repo, download it
            LocalModel::fetch(&inner_path, false)
                .await
                .map_err(to_pyerr)?
        };

309
310
        let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
        builder
311
            .model_path(model_path)
312
313
            .model_name(model_name)
            .context_length(context_length)
314
            .kv_cache_block_size(kv_cache_block_size)
315
            .router_config(Some(router_config))
316
            .migration_limit(Some(migration_limit))
317
            .runtime_config(runtime_config.unwrap_or_default().inner)
318
            .user_data(user_data_json)
319
320
321
            .custom_template_path(custom_template_path_owned)
            .media_decoder(media_decoder.map(|m| m.inner))
            .media_fetcher(media_fetcher.map(|m| m.inner));
322
        // Load the ModelDeploymentCard
323
        let mut local_model = builder.build().await.map_err(to_pyerr)?;
324
        // Advertise ourself so ingress can find us
325
        local_model
326
            .attach(&endpoint.inner, model_type_obj, model_input)
327
328
329
330
331
332
333
            .await
            .map_err(to_pyerr)?;

        Ok(())
    })
}

334
335
336
337
338
339
340
341
342
343
344
/// Download a model from Hugging Face, returning its local path.
/// Example: `model_path = await fetch_llm("Qwen/Qwen3-0.6B")`
///
/// Returns an awaitable. A failure from `LocalModel::fetch` is raised as `Exception`.
#[pyfunction]
#[pyo3(signature = (remote_name))]
fn fetch_llm<'p>(py: Python<'p>, remote_name: &str) -> PyResult<Bound<'p, PyAny>> {
    let repo_id = String::from(remote_name);
    let download = async move { LocalModel::fetch(&repo_id, false).await.map_err(to_pyerr) };
    pyo3_async_runtimes::tokio::future_into_py(py, download)
}

345
346
#[pyclass]
#[derive(Clone)]
Ryan Olson's avatar
Ryan Olson committed
347
pub struct DistributedRuntime {
348
349
350
351
    inner: rs::DistributedRuntime,
    event_loop: PyObject,
}

Ryan Olson's avatar
Ryan Olson committed
352
353
impl DistributedRuntime {
    #[allow(dead_code)]
354
    pub(crate) fn inner(&self) -> &rs::DistributedRuntime {
Ryan Olson's avatar
Ryan Olson committed
355
356
357
358
        &self.inner
    }
}

359
#[pyclass]
360
#[derive(Clone)]
361
362
363
364
365
struct CancellationToken {
    inner: rs::CancellationToken,
}

#[pyclass]
366
#[derive(Clone)]
367
368
369
370
371
372
struct Namespace {
    inner: rs::component::Namespace,
    event_loop: PyObject,
}

#[pyclass]
373
#[derive(Clone)]
374
375
376
377
378
379
struct Component {
    inner: rs::component::Component,
    event_loop: PyObject,
}

#[pyclass]
380
#[derive(Clone)]
381
382
383
384
385
386
struct Endpoint {
    inner: rs::component::Endpoint,
    event_loop: PyObject,
}

#[pyclass]
387
#[derive(Clone)]
388
struct Client {
389
    router: rs::pipeline::PushRouter<serde_json::Value, RsAnnotated<serde_json::Value>>,
390
391
}

392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
/// Python-facing wrapper around `dynamo_llm`'s model type.
///
/// Values are exposed as class attributes and can be combined with `|`.
#[pyclass]
#[derive(Clone, PartialEq)]
struct ModelType {
    inner: llm_rs::model_type::ModelType,
}

#[pymethods]
// Class attributes intentionally use the Python-facing CamelCase names.
#[allow(non_upper_case_globals)]
impl ModelType {
    #[classattr]
    const Chat: Self = Self {
        inner: llm_rs::model_type::ModelType::Chat,
    };

    #[classattr]
    const Completions: Self = Self {
        inner: llm_rs::model_type::ModelType::Completions,
    };

    #[classattr]
    const Embedding: Self = Self {
        inner: llm_rs::model_type::ModelType::Embedding,
    };

    #[classattr]
    const TensorBased: Self = Self {
        inner: llm_rs::model_type::ModelType::TensorBased,
    };

    #[classattr]
    const Prefill: Self = Self {
        inner: llm_rs::model_type::ModelType::Prefill,
    };

    /// `a | b`: delegates to the `|` operator of the inner model type.
    fn __or__(&self, other: &Self) -> Self {
        let combined = self.inner | other.inner;
        Self { inner: combined }
    }

    /// Display form of the inner model type.
    fn __str__(&self) -> String {
        self.inner.to_string()
    }
}

433
434
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)]
435
436
437
enum ModelInput {
    Text = 1,
    Tokens = 2,
438
    Tensor = 3,
439
440
}

441
442
443
#[pymethods]
impl DistributedRuntime {
    #[new]
444
    fn new(event_loop: PyObject, store_kv: String, request_plane: String) -> PyResult<Self> {
445
        let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?;
446
        let request_plane: RequestPlaneMode = request_plane.parse().map_err(to_pyerr)?;
447

448
449
450
        // Try to get existing runtime first, create new Worker only if needed
        // This allows multiple DistributedRuntime instances to share the same tokio runtime
        let runtime = rs::Worker::runtime_from_existing()
451
            .or_else(|_| -> anyhow::Result<rs::Runtime> {
452
453
454
455
                // No existing Worker, create new one
                let worker = rs::Worker::from_settings()?;

                // Initialize pyo3 bridge (only happens once per process)
456
                INIT.get_or_try_init(|| -> anyhow::Result<()> {
457
458
                    let primary = worker.tokio_runtime()?;
                    pyo3_async_runtimes::tokio::init_with_runtime(primary).map_err(|e| {
459
                        anyhow::anyhow!("failed to initialize pyo3 static runtime: {:?}", e)
460
                    })?;
461
                    Ok(())
462
463
                })?;

464
                Ok(worker.runtime().clone())
465
466
            })
            .map_err(to_pyerr)?;
467

468
469
        // Initialize logging in context where tokio runtime is available
        // otel exporter requires it
470
471
472
473
        if std::env::var(env_otlp::OTEL_EXPORT_ENABLED)
            .map(|v| v == "1")
            .unwrap_or(false)
        {
474
475
476
477
            runtime.secondary().block_on(async {
                rs::logging::init();
            });
        }
478

479
480
481
        let runtime_config = DistributedConfig {
            store_backend: selected_kv_store,
            nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
482
            request_plane,
483
484
485
486
487
        };
        let inner = runtime
            .secondary()
            .block_on(rs::DistributedRuntime::new(runtime, runtime_config))
            .map_err(to_pyerr)?;
488
489
490
491

        Ok(DistributedRuntime { inner, event_loop })
    }

Ryan Olson's avatar
Ryan Olson committed
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
    #[staticmethod]
    fn detached(py: Python) -> PyResult<Self> {
        let rt = rs::Worker::runtime_from_existing().map_err(to_pyerr)?;
        let handle = rt.primary();

        let inner = handle
            .block_on(rs::DistributedRuntime::from_settings(rt))
            .map_err(to_pyerr)?;

        Ok(DistributedRuntime {
            inner,
            event_loop: py.None(),
        })
    }

507
508
509
510
511
512
513
514
    fn namespace(&self, name: String) -> PyResult<Namespace> {
        Ok(Namespace {
            inner: self.inner.namespace(name).map_err(to_pyerr)?,
            event_loop: self.event_loop.clone(),
        })
    }

    fn shutdown(&self) {
515
        self.inner.shutdown();
516
517
518
519
520
    }

    fn event_loop(&self) -> PyObject {
        self.event_loop.clone()
    }
521
522
523
524
525

    fn child_token(&self) -> CancellationToken {
        let inner = self.inner.runtime().child_token();
        CancellationToken { inner }
    }
Richard Huo's avatar
Richard Huo committed
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540

    // This is used to pass the DistributedRuntime from the dynamo-runtime bindings
    // to the KVBM bindings, since KVBM cannot directly use the struct from this cdylib.
    // TODO: Create a separate crate "dynamo-python" so that all binding crates can import
    // from it and share the same crate path. This will allow PyO3 to automatically
    // recognize that both bindings use the same PyClass.
    #[pyo3(name = "to_capsule")]
    fn to_capsule<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
        let arc: Arc<rs::DistributedRuntime> = Arc::new(self.inner.clone());
        let weak: Weak<rs::DistributedRuntime> = Arc::downgrade(&arc);

        let name = CString::new("dynamo.runtime.weak").expect("valid capsule name");

        PyCapsule::new(py, weak, Some(name))
    }
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
}

#[pymethods]
impl CancellationToken {
    fn cancel(&self) {
        self.inner.cancel();
    }

    fn cancelled<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let token = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            token.cancelled().await;
            Ok(())
        })
    }
}

#[pymethods]
impl Component {
    fn endpoint(&self, name: String) -> PyResult<Endpoint> {
        let inner = self.inner.endpoint(name);
        Ok(Endpoint {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }

568
    /// NATS specific stats/metrics call
569
    fn create_service<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
570
        let mut inner = self.inner.clone();
571
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
572
            inner.add_stats_service().await.map_err(to_pyerr)?;
573
574
575
            Ok(())
        })
    }
576
577
578

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
579
580
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_component(self.inner.clone())
581
    }
582
583
584
585
}

#[pymethods]
impl Endpoint {
586
    #[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None))]
587
588
589
590
    fn serve_endpoint<'p>(
        &self,
        py: Python<'p>,
        generator: PyObject,
591
        graceful_shutdown: Option<bool>,
592
        metrics_labels: Option<Vec<(String, String)>>,
593
        health_check_payload: Option<&Bound<'p, PyDict>>,
594
595
596
597
598
599
    ) -> PyResult<Bound<'p, PyAny>> {
        let engine = Arc::new(engine::PythonAsyncEngine::new(
            generator,
            self.event_loop.clone(),
        )?);
        let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
600
601
602
603
604
605
606
607
608
609
610
611
612

        // Convert Python dict to serde_json::Value if provided and validate it's an object
        let health_payload_json = health_check_payload
            .map(|dict| pythonize::depythonize::<serde_json::Value>(dict))
            .transpose()
            .map_err(|err| {
                pyo3::exceptions::PyTypeError::new_err(format!(
                    "Failed to convert health_check_payload: {}",
                    err
                ))
            })?;

        // Require an object/dict
613
614
615
616
617
618
        if let Some(ref payload) = health_payload_json
            && !payload.is_object()
        {
            return Err(pyo3::exceptions::PyTypeError::new_err(
                "health_check_payload must be a JSON object (dict)",
            ));
619
620
621
        }

        let mut builder = self
622
623
624
625
            .inner
            .endpoint_builder()
            .metrics_labels(metrics_labels)
            .handler(ingress);
626
627
628
629
630

        if let Some(payload) = health_payload_json {
            builder = builder.health_check_payload(payload);
        }

631
        let graceful_shutdown = graceful_shutdown.unwrap_or(true);
632
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
633
634
635
636
637
            builder
                .graceful_shutdown(graceful_shutdown)
                .start()
                .await
                .map_err(to_pyerr)?;
638
639
640
641
            Ok(())
        })
    }

642
643
644
    fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
645
            let client = inner.client().await.map_err(to_pyerr)?;
646
647
648
649
650
651
            let push_router = rs::pipeline::PushRouter::<
                serde_json::Value,
                RsAnnotated<serde_json::Value>,
            >::from_client(client, Default::default())
            .await
            .map_err(to_pyerr)?;
652
653
654
            Ok(Client {
                router: push_router,
            })
655
656
        })
    }
657

658
659
660
    // Opaque unique ID for this worker. May change over worker lifetime.
    fn connection_id(&self) -> u64 {
        self.inner.drt().connection_id()
661
    }
662
663
664

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
665
666
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_endpoint(self.inner.clone())
667
    }
668
669
670
671
672
673
674
675
676
677
678
}

#[pymethods]
impl Namespace {
    fn component(&self, name: String) -> PyResult<Component> {
        let inner = self.inner.component(name).map_err(to_pyerr)?;
        Ok(Component {
            inner,
            event_loop: self.event_loop.clone(),
        })
    }
679
680
681

    /// Get a RuntimeMetrics helper for creating Prometheus metrics
    #[getter]
682
683
    fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
        prometheus_metrics::RuntimeMetrics::from_namespace(self.inner.clone())
684
    }
685
686
687
688
}

#[pymethods]
impl Client {
689
690
    /// Get list of current instances.
    /// Replaces endpoint_ids.
691
    fn instance_ids(&self) -> Vec<u64> {
692
        self.router.client.instance_ids()
693
694
    }

695
696
697
    /// Wait for an instance to be available for work.
    /// Replaces wait_for_endpoints.
    fn wait_for_instances<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
698
        let inner = self.router.client.clone();
699
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
700
            inner
701
                .wait_for_instances()
702
                .await
703
                .map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<u64>>())
704
                .map_err(to_pyerr)
705
706
707
708
        })
    }

    /// Issue a request to the endpoint using the default routing strategy.
709
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
710
711
712
713
714
    fn generate<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
715
        context: Option<context::Context>,
716
    ) -> PyResult<Bound<'p, PyAny>> {
717
        self.random(py, request, annotated, context)
718
719
720
    }

    /// Send a request to the next endpoint in a round-robin fashion.
721
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
722
723
724
725
726
    fn round_robin<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
727
        context: Option<context::Context>,
728
729
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
730
        let request_ctx = create_request_context(request, &context);
731
732
733
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
734
        let client = self.router.clone();
735
736

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
737
738
            let stream = match context {
                Some(context) => {
739
740
                    // Always instrument with appropriate span (none if no trace context)
                    let span = get_span_for_context(&context, "round_robin");
741
742
743
744
745
                    client
                        .round_robin(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
746
                }
747
                _ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
748
            };
749
750
751
752
753
754
755
756
757
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Send a request to a random endpoint.
758
    #[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
759
760
761
762
763
    fn random<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
        annotated: Option<bool>,
764
        context: Option<context::Context>,
765
766
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
767
        let request_ctx = create_request_context(request, &context);
768
769
770
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
771
        let client = self.router.clone();
772
773

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
774
775
            let stream = match context {
                Some(context) => {
776
                    // Always instrument with appropriate span (none if no trace context)
777
                    let span = get_span_for_context(&context, "random");
778
779
780
781
782
                    client
                        .random(request_ctx)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
783
                }
784
                _ => client.random(request_ctx).await.map_err(to_pyerr)?,
785
            };
786
787
788
789
790
791
792
793
794
            tokio::spawn(process_stream(stream, tx));
            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }

    /// Directly send a request to a specific endpoint.
795
    #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
796
797
798
799
    fn direct<'p>(
        &self,
        py: Python<'p>,
        request: PyObject,
800
        instance_id: u64,
801
        annotated: Option<bool>,
802
        context: Option<context::Context>,
803
804
    ) -> PyResult<Bound<'p, PyAny>> {
        let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
805
        let request_ctx = create_request_context(request, &context);
806
807
808
        let annotated = annotated.unwrap_or(false);

        let (tx, rx) = tokio::sync::mpsc::channel(32);
809
        let client = self.router.clone();
810
811

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
812
813
            let stream = match context {
                Some(context) => {
814
                    // Always instrument with appropriate span (none if no trace context)
815
816
                    let span =
                        get_span_for_direct_context(&context, "direct", &instance_id.to_string());
817
818
819
820
821
                    client
                        .direct(request_ctx, instance_id)
                        .instrument(span)
                        .await
                        .map_err(to_pyerr)?
822
823
                }
                _ => client
824
                    .direct(request_ctx, instance_id)
825
826
827
                    .await
                    .map_err(to_pyerr)?,
            };
828
829
830
831
832
833
834
835
836
837
838
839

            tokio::spawn(process_stream(stream, tx));

            Ok(AsyncResponseStream {
                rx: Arc::new(Mutex::new(rx)),
                annotated,
            })
        })
    }
}

async fn process_stream(
840
    stream: EngineStream<RsAnnotated<serde_json::Value>>,
841
842
843
844
845
    tx: tokio::sync::mpsc::Sender<RsAnnotated<PyObject>>,
) {
    let mut stream = stream;
    while let Some(response) = stream.next().await {
        // Convert the response to a PyObject using Python's GIL
846
        let annotated: RsAnnotated<serde_json::Value> = response;
847
        let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
848
            Python::with_gil(|py| match pythonize::pythonize(py, &data) {
849
850
                Ok(pyobj) => Ok(pyobj.into()),
                Err(e) => Err(e.to_string()),
851
            })
852
853
854
855
856
857
        });

        let is_error = annotated.is_error();

        // Send the PyObject through the channel or log an error
        if let Err(e) = tx.send(annotated).await {
858
            tracing::error!("Failed to send response: {:?}", e);
859
            break;
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
        }

        if is_error {
            break;
        }
    }
}

/// Python async iterator over the responses forwarded by `process_stream`.
///
/// Depending on `annotated`, each step yields either the bare response payload or an
/// `Annotated` wrapper around the whole response.
#[pyclass]
struct AsyncResponseStream {
    /// Receiving half of the channel that `process_stream` writes into. A tokio `Mutex`
    /// is used because the lock is held across `recv().await`.
    rx: Arc<Mutex<tokio::sync::mpsc::Receiver<RsAnnotated<PyObject>>>>,
    /// Yield `Annotated` wrappers (`true`) or bare payloads (`false`).
    annotated: bool,
}

#[pymethods]
impl AsyncResponseStream {
    /// This method is required to implement the `AsyncIterator` protocol.
    #[pyo3(name = "__aiter__")]
    fn aiter(slf: PyRef<Self>, py: Python) -> PyResult<Py<PyAny>> {
        slf.into_py_any(py)
    }

    /// This method is required to implement the `AsyncIterator` protocol.
    ///
    /// Returns an awaitable that resolves to the next item:
    /// - raises `StopAsyncIteration` once the channel is closed and drained;
    /// - raises `ValueError` when `ok()` reports the received response as an error;
    /// - otherwise yields an `Annotated` wrapper (annotated mode) or the payload,
    ///   skipping responses that carry no payload (payload mode).
    #[pyo3(name = "__anext__")]
    fn next<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let rx = self.rx.clone();
        let annotated = self.annotated;

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            loop {
                // The guard is a temporary of this statement, so the lock is held for the
                // whole `recv().await`; concurrent `__anext__` calls are served in turn.
                let value = rx.lock().await.recv().await;
                let Some(response) = value else {
                    return Err(PyStopAsyncIteration::new_err("Stream exhausted"));
                };

                let response = match response.ok() {
                    Ok(response) => response,
                    Err(e) => return Err(pyo3::exceptions::PyValueError::new_err(e)),
                };

                if annotated {
                    // `into_py_any` replaces the deprecated `IntoPy::into_py`, which panics
                    // if the Python object cannot be allocated; here that failure is
                    // surfaced to Python as an exception instead.
                    let object = Annotated { inner: response };
                    return Python::with_gil(|py| object.into_py_any(py));
                }

                if let Some(data) = response.data {
                    return Ok(data);
                }
                // This response carried no payload: wait for the next one.
            }
        })
    }
}

/// Python-facing wrapper around a runtime annotated response, exposing its payload
/// together with the optional `event`, `comment` and `id` fields.
#[pyclass]
struct Annotated {
    inner: RsAnnotated<PyObject>,
}

#[pymethods]
impl Annotated {
    /// Wrap `data` using `RsAnnotated::from_data`.
    #[new]
    fn new(data: PyObject) -> Self {
        Self {
            inner: RsAnnotated::from_data(data),
        }
    }

    /// Whether the wrapped response reports itself as an error.
    fn is_error(&self) -> bool {
        self.inner.is_error()
    }

    /// The payload, if present.
    fn data(&self) -> Option<PyObject> {
        self.inner.data.as_ref().cloned()
    }

    /// The `event` field, if present.
    fn event(&self) -> Option<String> {
        self.inner.event.as_ref().cloned()
    }

    /// The `comment` field, if present.
    fn comments(&self) -> Option<Vec<String>> {
        self.inner.comment.as_ref().cloned()
    }

    /// The `id` field, if present.
    fn id(&self) -> Option<String> {
        self.inner.id.as_ref().cloned()
    }

    /// Render as `Annotated(data=..., event=..., comment=[...], id=...)`.
    ///
    /// The payload is shown via its Python `__repr__`; `<failed_repr>` is used when that
    /// call fails or does not return a `str`, and `<no_data>` when there is no payload.
    #[pyo3(name = "__repr__")]
    fn _repr(&self, py: Python) -> String {
        let data_repr = match &self.inner.data {
            None => String::from("<no_data>"),
            Some(obj) => obj
                .call_method0(py, "__repr__")
                .and_then(|repr_obj| repr_obj.extract::<Py<PyString>>(py))
                .map(|text| text.to_string_lossy(py).into_owned())
                .unwrap_or_else(|_| String::from("<failed_repr>")),
        };
        let event = self.inner.event.as_deref().unwrap_or("None");
        let comments: &[String] = self.inner.comment.as_deref().unwrap_or(&[]);
        let id = self.inner.id.as_deref().unwrap_or("None");

        format!("Annotated(data={data_repr}, event={event}, comment={comments:?}, id={id})")
    }
}