Commit 1b96c2c4 authored by Graham King, committed by GitHub

feat: Bring-your-own engine for dynemo-run (#43)

1. Create `my_engine.py`

```
import asyncio

async def generate(request):
    yield {"id":"1","choices":[{"index":0,"delta":{"content":"The","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":" capital","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":" of","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":" France","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":" is","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":" Paris","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":".","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
    await asyncio.sleep(0.1)
    yield {"id":"1","choices":[{"index":0,"delta":{"content":"","role":"assistant"},"finish_reason":"stop"}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
```

2. Build

```
cargo build --release --features python
```

3. Run

```
dynemo-run out=pystr:my_engine.py --model-name test
```

And here's a distributed setup with your engine:

- Node 1: `dynemo-run in=http out=dyn://test`
- Node 2: `dynemo-run in=dyn://test out=pystr:my_engine.py`
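
To smoke-test the HTTP frontend (it defaults to port 8080 per the usage string), you can stream a completion through it. A minimal sketch, assuming the frontend exposes the standard OpenAI `/v1/chat/completions` route:

```
import json
import urllib.request

# Hypothetical local smoke test; the route and port are assumptions.
req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps({
        "model": "test",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "stream": True,
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for line in resp:  # one server-sent-event line per chunk
        print(line.decode().rstrip())
```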
parent 3c60fe2a
@@ -160,6 +160,18 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063"
[[package]]
name = "async-channel"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
dependencies = [
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-nats"
version = "0.38.0"
@@ -868,6 +880,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "console"
version = "0.15.10"
@@ -1405,6 +1426,8 @@ dependencies = [
"mistralrs",
"prometheus",
"pyo3",
"pyo3-async-runtimes",
"pythonize",
"regex",
"semver",
"serde",
@@ -1662,6 +1685,27 @@ dependencies = [
"tower-service",
]
[[package]]
name = "event-listener"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
@@ -2620,6 +2664,15 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "inventory"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83"
dependencies = [
"rustversion",
]
[[package]]
name = "iovec"
version = "0.1.4"
@@ -3679,6 +3732,12 @@ dependencies = [
"serde",
]
[[package]]
name = "parking"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]]
name = "parking_lot"
version = "0.12.3"
@@ -4004,6 +4063,34 @@ dependencies = [
"unindent",
]
[[package]]
name = "pyo3-async-runtimes"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "977dc837525cfd22919ba6a831413854beb7c99a256c03bf8624ad707e45810e"
dependencies = [
"async-channel",
"clap",
"futures",
"inventory",
"once_cell",
"pin-project-lite",
"pyo3",
"pyo3-async-runtimes-macros",
"tokio",
]
[[package]]
name = "pyo3-async-runtimes-macros"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2df2884957d2476731f987673befac5d521dff10abb0a7cbe12015bc7702fe9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
@@ -4049,6 +4136,16 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "pythonize"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91a6ee7a084f913f98d70cdc3ebec07e852b735ae3059a1500db2661265da9ff"
dependencies = [
"pyo3",
"serde",
]
[[package]]
name = "qoi"
version = "0.4.1"
@@ -27,6 +27,7 @@ sglang = ["dynemo-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
vllm = ["dynemo-llm/vllm", "dep:netlink-packet-route", "dep:rtnetlink"]
llamacpp = ["dynemo-llm/llamacpp"]
trtllm = ["dynemo-llm/trtllm"]
python = ["dynemo-llm/python"]
cuda = ["dynemo-llm/cuda"]
metal = ["dynemo-llm/metal"]
@@ -156,6 +156,43 @@ Node 2:
dynemo-run in=none out=vllm ~/llm_models/Llama-3.2-3B-Instruct/ --num-nodes 2 --leader-addr 10.217.98.122:6539 --node-rank 1
```
## python
You can supply your own engine in a Python file. The file must expose an async generator with this signature:
```
async def generate(request):
```
- The `request` parameter is a map: an OpenAI-compatible create chat completion request (https://platform.openai.com/docs/api-reference/chat/create); a sketch follows this list.
- The function must `yield` a series of maps conforming to the create chat completion stream response (full example engine below).
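For orientation, a minimal sketch of what the `request` map typically carries; the exact fields depend on what the client sends:
```
{
    "model": "Llama-3.2-1B-Instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "stream": True,
}
```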
The file is loaded once at startup and kept in memory.
- Build: `cargo build --release --features python`
- Run: `dynemo-run out=pystr:/home/user/my_python_engine.py --model-name <model-name>`
Example engine:
```
import asyncio
async def generate(request):
yield {"id":"1","choices":[{"index":0,"delta":{"content":"The","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":" capital","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":" of","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":" France","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":" is","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":" Paris","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":".","role":"assistant"}}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
await asyncio.sleep(0.1)
yield {"id":"1","choices":[{"index":0,"delta":{"content":"","role":"assistant"},"finish_reason":"stop"}],"created":1841762283,"model":"Llama-3.2-1B-Instruct","system_fingerprint":"local","object":"chat.completion.chunk"}
```
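You can also pair the Python engine with the other input frontends, e.g. `dynemo-run in=http out=pystr:my_engine.py --model-name <model-name>` to serve it over HTTP locally (assuming the `http` input composes with `pystr` the same way it does with the built-in engines).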
## trtllm
@@ -43,6 +43,10 @@ pub use opt::{Input, Output};
/// concatenations.
const ENDPOINT_SCHEME: &str = "dyn://";
/// How we identify a python string endpoint
#[cfg(feature = "python")]
const PYTHON_STR_SCHEME: &str = "pystr:";
pub enum EngineConfig {
/// A remote networked engine we don't know about yet
/// We don't have the pre-processor yet so this is only text requests. Type will change later.
......@@ -334,6 +338,19 @@ pub async fn run(
card: Box::new(card),
}
}
#[cfg(feature = "python")]
Output::PythonStr(path_str) => {
use dynemo_llm::engines::python;
let Some(model_name) = model_name else {
anyhow::bail!("Provide model service name as `--model-name <this>`");
};
let p = std::path::PathBuf::from(path_str);
let engine = python::make_string_engine(&p).await?;
EngineConfig::StaticFull {
service_name: model_name,
engine,
}
}
};
match in_opt {
@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "dyn";
const USAGE: &str = "USAGE: dynemo-run in=[http|text|dyn://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0]";
const USAGE: &str = "USAGE: dynemo-run in=[http|text|dyn://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core|pystr:<engine.py>] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
@@ -91,6 +91,11 @@ pub enum Output {
#[cfg(feature = "trtllm")]
/// Run inference using trtllm
TrtLLM,
/// Run inference using a user-supplied Python file that accepts and returns
/// strings (meaning it does its own pre-processing).
#[cfg(feature = "python")]
PythonStr(String),
}
impl TryFrom<&str> for Output {
@@ -121,6 +126,14 @@ impl TryFrom<&str> for Output {
Ok(Output::Endpoint(path.to_string()))
}
#[cfg(feature = "python")]
python_str_gen if python_str_gen.starts_with(crate::PYTHON_STR_SCHEME) => {
let path = python_str_gen
.strip_prefix(crate::PYTHON_STR_SCHEME)
.unwrap();
Ok(Output::PythonStr(path.to_string()))
}
e => Err(anyhow::anyhow!("Invalid out= option '{e}'")),
}
}
@@ -148,6 +161,9 @@ impl fmt::Display for Output {
Output::EchoCore => "echo_core",
Output::Endpoint(path) => path,
#[cfg(feature = "python")]
Output::PythonStr(path) => path,
};
write!(f, "{s}")
}
@@ -160,6 +160,18 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063"
[[package]]
name = "async-channel"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
dependencies = [
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-nats"
version = "0.38.0"
@@ -893,6 +905,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "concurrent-queue"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "console"
version = "0.15.10"
@@ -1421,6 +1442,8 @@ dependencies = [
"prometheus",
"proptest",
"pyo3",
"pyo3-async-runtimes",
"pythonize",
"regex",
"reqwest",
"rstest",
@@ -1657,6 +1680,27 @@ dependencies = [
"tower-service",
]
[[package]]
name = "event-listener"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener-strategy"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2"
dependencies = [
"event-listener",
"pin-project-lite",
]
[[package]]
name = "eventsource-stream"
version = "0.2.3"
@@ -2670,6 +2714,15 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "inventory"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83"
dependencies = [
"rustversion",
]
[[package]]
name = "iovec"
version = "0.1.4"
@@ -3680,6 +3733,12 @@ dependencies = [
"serde",
]
[[package]]
name = "parking"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]]
name = "parking_lot"
version = "0.12.3"
@@ -4093,6 +4152,34 @@ dependencies = [
"unindent",
]
[[package]]
name = "pyo3-async-runtimes"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "977dc837525cfd22919ba6a831413854beb7c99a256c03bf8624ad707e45810e"
dependencies = [
"async-channel",
"clap",
"futures",
"inventory",
"once_cell",
"pin-project-lite",
"pyo3",
"pyo3-async-runtimes-macros",
"tokio",
]
[[package]]
name = "pyo3-async-runtimes-macros"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2df2884957d2476731f987673befac5d521dff10abb0a7cbe12015bc7702fe9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
@@ -4138,6 +4225,16 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "pythonize"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91a6ee7a084f913f98d70cdc3ebec07e852b735ae3059a1500db2661265da9ff"
dependencies = [
"pyo3",
"serde",
]
[[package]]
name = "qoi"
version = "0.4.1"
@@ -35,6 +35,7 @@ llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"]
vllm = ["dep:async_zmq"]
python = ["dep:pyo3-async-runtimes", "dep:pythonize"]
trtllm = []
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
@@ -89,6 +90,12 @@ strum = { workspace = true }
async-openai = "0.27.2"
blake3 = "1"
regex = "1"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
# protocols
chrono = { version = "0.4", default-features = false, features = [
@@ -113,12 +120,6 @@ mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e6
# sglang
async_zmq = { version = "0.4.0", optional = true }
libc = "0.2"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
serde-pickle = "1.2.0"
# llamacpp
@@ -149,6 +150,16 @@ semver = { version = "1", features = ["serde"] }
# trtllm
serde_repr = "0.1"
# python
pyo3-async-runtimes = { version = "0.23.0", optional = true, default-features = false, features = [
"attributes",
"testing",
"tokio-runtime",
"unstable-streams",
] }
pythonize = { version = "0.23", optional = true }
[dev-dependencies]
proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
@@ -28,6 +28,9 @@ pub mod vllm;
#[cfg(feature = "trtllm")]
pub mod trtllm;
#[cfg(feature = "python")]
pub mod python;
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
/// How many nodes / hosts we are using
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ffi::CStr;
use std::{path::Path, sync::Arc};
use dynemo_runtime::pipeline::error as pipeline_error;
pub use dynemo_runtime::{
error,
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream,
SingleIn,
},
protocols::annotated::Annotated,
Error, Result,
};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict};
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
pub use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::sync::oneshot::Sender;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
/// Python snippet to import a file as a module
const PY_IMPORT: &CStr = cr#"
import importlib.util
import sys
module_name = file_path.split("/")[-1].replace(".py", "")
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
"#;
/// An engine that takes and returns strings, feeding them to a python written engine
pub async fn make_string_engine(
py_file: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
pyo3::prepare_freethreaded_python();
let engine = PythonStringEngine::new(py_file).await?;
let engine: OpenAIChatCompletionsStreamingEngine = Arc::new(engine);
Ok(engine)
}
struct PythonStringEngine {
_user_module: PyObject,
generator: Arc<Py<PyAny>>,
event_loop: Arc<Py<PyAny>>,
}
impl PythonStringEngine {
async fn new(py_file: &Path) -> anyhow::Result<Self> {
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::task::spawn_blocking(move || run_asyncio(tx));
let event_loop = rx.await?;
let user_module = python_file_to_module(py_file)?;
let generator = Python::with_gil(|py| user_module.getattr(py, "generate").unwrap());
Ok(PythonStringEngine {
_user_module: user_module,
generator: Arc::new(generator),
event_loop,
})
}
}
/// Start asyncio event loop and block on it forever
fn run_asyncio(tx: Sender<Arc<PyObject>>) {
let event_loop: PyObject = Python::with_gil(|py| {
let aio: PyObject = py.import("asyncio").unwrap().into();
aio.call_method0(py, "new_event_loop").unwrap()
});
let event_loop = Arc::new(event_loop);
let _ = tx.send(event_loop.clone());
Python::with_gil(|py| {
let _ = event_loop.call_method0(py, "run_forever");
});
}
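/// Import the file at `p` as a Python module: run the `PY_IMPORT` snippet with
/// `file_path` bound in its globals, then pull the resulting `module` out of the locals.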
fn python_file_to_module(p: &Path) -> Result<PyObject> {
let module: PyObject = Python::with_gil(|py| {
let globals = [("file_path", p.display().to_string())]
.into_py_dict(py)
.unwrap();
let locals = PyDict::new(py);
py.run(PY_IMPORT, Some(&globals), Some(&locals)).unwrap();
let module = locals.get_item("module").unwrap().unwrap();
module.extract().unwrap()
});
Ok(module)
}
#[derive(Debug, thiserror::Error)]
enum ResponseProcessingError {
#[error("python exception: {0}")]
PythonException(String),
#[error("deserialize error: {0}")]
DeserializeError(String),
#[error("gil offload error: {0}")]
OffloadError(String),
}
#[async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for PythonStringEngine
where
Req: Data + Serialize,
Resp: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
// Create a context
let (request, context) = request.transfer(());
let ctx = context.context();
let id = context.id().to_string();
tracing::trace!("processing request: {}", id);
// Clone the PyObject to move into the thread
// Create a channel to communicate between the Python thread and the Rust async context
let (tx, rx) = mpsc::channel::<Annotated<Resp>>(128);
let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in a tokio async task could block the thread for an undefined amount of time
// To avoid this, we spawn a blocking task to acquire the GIL and perform the operations needed
// while holding the GIL.
//
// Under low GIL contention, we wouldn't need to do this.
// However, under high GIL contention, this can lead to significant performance degradation.
//
// Since we cannot predict the GIL contention, we will always use the blocking task and pay the
// cost. The Python GIL is the gift that keeps on giving -- performance hits...
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let gen = generator.call1(py, (py_request,))?;
let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::tokio::into_stream_with_locals_v1(locals, gen.into_bound(py))
})
})
.await??;
let stream = Box::pin(stream);
// process the stream
// any error thrown in the stream will be caught and complete the processing task
// errors are captured by a task that is watching the processing task
// the error will be emitted as an annotated error
let request_id = id.clone();
tokio::spawn(async move {
tracing::debug!(
request_id,
"starting task to process python async generator stream"
);
let mut stream = stream;
let mut count = 0;
while let Some(item) = stream.next().await {
count += 1;
tracing::trace!(
request_id,
"processing the {}th item from python async generator",
count
);
let mut done = false;
let response = match process_item::<Resp>(item).await {
Ok(response) => response,
Err(e) => {
done = true;
let msg = match &e {
ResponseProcessingError::DeserializeError(e) => {
// tell the python async generator to stop generating
// right now, this is impossible as we are not passing the context to the python async generator
// todo: add task-local context to the python async generator
ctx.stop_generating();
let msg = format!("critical error: invalid response object from python async generator; application-logic-mismatch: {}", e);
tracing::error!(request_id, "{}", msg);
msg
}
ResponseProcessingError::PythonException(e) => {
let msg = format!("a python exception was caught while processing the async generator: {}", e);
tracing::warn!(request_id, "{}", msg);
msg
}
ResponseProcessingError::OffloadError(e) => {
let msg = format!("critical error: failed to offload the python async generator to a new thread: {}", e);
tracing::error!(request_id, "{}", msg);
msg
}
};
Annotated::from_error(msg)
}
};
if tx.send(response).await.is_err() {
tracing::trace!(
request_id,
"error forwarding annotated response to channel; channel is closed"
);
break;
}
if done {
tracing::debug!(
request_id,
"early termination of python async generator stream task"
);
break;
}
}
tracing::debug!(
request_id,
"finished processing python async generator stream"
);
});
let stream = ReceiverStream::new(rx);
Ok(ResponseStream::new(Box::pin(stream), context.context()))
}
}
async fn process_item<Resp>(
item: Result<Py<PyAny>, PyErr>,
) -> Result<Annotated<Resp>, ResponseProcessingError>
where
Resp: Data + for<'de> Deserialize<'de>,
{
let item = item.map_err(|e| ResponseProcessingError::PythonException(e.to_string()))?;
let response = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| depythonize::<Resp>(&item.into_bound(py)))
})
.await
.map_err(|e| ResponseProcessingError::OffloadError(e.to_string()))?
.map_err(|e| ResponseProcessingError::DeserializeError(e.to_string()))?;
let response = Annotated::from_data(response);
Ok(response)
}