Commit 6ca24080 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

fix: Improve PythonAsyncEngine error handling and Increase Tokio thread count (#129)


Signed-off-by: default avatarRyan Olson <ryanolson@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 910751c6
...@@ -26,6 +26,56 @@ dependencies = [ ...@@ -26,6 +26,56 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "anstream"
version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
[[package]]
name = "anstyle-parse"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
"once_cell",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.95" version = "1.0.95"
...@@ -290,6 +340,12 @@ version = "0.2.1" ...@@ -290,6 +340,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "colorchoice"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
...@@ -561,6 +617,29 @@ dependencies = [ ...@@ -561,6 +617,29 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "env_filter"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.1" version = "1.0.1"
...@@ -854,6 +933,12 @@ version = "1.0.3" ...@@ -854,6 +933,12 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.6.0" version = "1.6.0"
...@@ -1078,6 +1163,12 @@ version = "0.1.15" ...@@ -1078,6 +1163,12 @@ version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb"
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.13.0" version = "0.13.0"
...@@ -2340,6 +2431,7 @@ dependencies = [ ...@@ -2340,6 +2431,7 @@ dependencies = [
"derive_builder", "derive_builder",
"educe", "educe",
"either", "either",
"env_logger",
"etcd-client", "etcd-client",
"figment", "figment",
"futures", "futures",
...@@ -2430,6 +2522,12 @@ version = "1.0.4" ...@@ -2430,6 +2522,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.12.1" version = "1.12.1"
......
...@@ -25,6 +25,10 @@ homepage = "https://github.com/triton-inference-server/triton_distributed" ...@@ -25,6 +25,10 @@ homepage = "https://github.com/triton-inference-server/triton_distributed"
repository = "https://github.com/triton-inference-server/triton_distributed" repository = "https://github.com/triton-inference-server/triton_distributed"
keywords = ["llm", "genai", "inference", "nvidia", "distributed", "triton"] keywords = ["llm", "genai", "inference", "nvidia", "distributed", "triton"]
[features]
default = []
integration = []
[dependencies] [dependencies]
# workspace - when we expand to multiple crates; put these in the workspace # workspace - when we expand to multiple crates; put these in the workspace
anyhow = { version = "1" } anyhow = { version = "1" }
...@@ -64,3 +68,4 @@ rand = { version = "0.8"} ...@@ -64,3 +68,4 @@ rand = { version = "0.8"}
[dev-dependencies] [dev-dependencies]
assert_matches = "1.5.0" assert_matches = "1.5.0"
env_logger = "0.11"
...@@ -131,7 +131,7 @@ dependencies = [ ...@@ -131,7 +131,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_nanos", "serde_nanos",
"serde_repr", "serde_repr",
"thiserror", "thiserror 1.0.69",
"time", "time",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
...@@ -1271,7 +1271,7 @@ checksum = "3669cf5561f8d27e8fc84cc15e58350e70f557d4d65f70e3154e54cd2f8e1782" ...@@ -1271,7 +1271,7 @@ checksum = "3669cf5561f8d27e8fc84cc15e58350e70f557d4d65f70e3154e54cd2f8e1782"
dependencies = [ dependencies = [
"libc", "libc",
"neli", "neli",
"thiserror", "thiserror 1.0.69",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
...@@ -1377,7 +1377,7 @@ checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3" ...@@ -1377,7 +1377,7 @@ checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
dependencies = [ dependencies = [
"rand", "rand",
"serde", "serde",
"thiserror", "thiserror 1.0.69",
] ]
[[package]] [[package]]
...@@ -1649,7 +1649,7 @@ dependencies = [ ...@@ -1649,7 +1649,7 @@ dependencies = [
"memchr", "memchr",
"parking_lot", "parking_lot",
"protobuf", "protobuf",
"thiserror", "thiserror 1.0.69",
] ]
[[package]] [[package]]
...@@ -2298,7 +2298,16 @@ version = "1.0.69" ...@@ -2298,7 +2298,16 @@ version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl 1.0.69",
]
[[package]]
name = "thiserror"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
dependencies = [
"thiserror-impl 2.0.11",
] ]
[[package]] [[package]]
...@@ -2312,6 +2321,17 @@ dependencies = [ ...@@ -2312,6 +2321,17 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "thiserror-impl"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]] [[package]]
name = "time" name = "time"
version = "0.3.37" version = "0.3.37"
...@@ -2620,7 +2640,7 @@ dependencies = [ ...@@ -2620,7 +2640,7 @@ dependencies = [
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror 1.0.69",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
...@@ -2641,6 +2661,7 @@ dependencies = [ ...@@ -2641,6 +2661,7 @@ dependencies = [
"pythonize", "pythonize",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.11",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tracing", "tracing",
......
...@@ -36,6 +36,7 @@ futures = "0.3" ...@@ -36,6 +36,7 @@ futures = "0.3"
once_cell = "1.20.3" once_cell = "1.20.3"
serde = "1" serde = "1"
serde_json = "1.0.138" serde_json = "1.0.138"
thiserror = "2.0"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-stream = "0" tokio-stream = "0"
tracing = "0" tracing = "0"
......
...@@ -34,8 +34,8 @@ apt install protobuf-compiler ...@@ -34,8 +34,8 @@ apt install protobuf-compiler
``` ```
3. Setup a virtualenv 3. Setup a virtualenv
``` ```
cd python-wheels/triton-distributed
uv venv uv venv
source .venv/bin/activate source .venv/bin/activate
uv pip install maturin uv pip install maturin
......
...@@ -38,6 +38,18 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -38,6 +38,18 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(()) Ok(())
} }
/// Errors that can occur while turning one item yielded by the Python
/// async generator into a Rust response object.
#[derive(Debug, thiserror::Error)]
enum ResponseProcessingError {
    /// The Python generator itself raised an exception while producing an item.
    #[error("python exception: {0}")]
    PythonException(String),
    /// The yielded Python object could not be deserialized into the expected
    /// response type (an application-logic mismatch between client and server).
    #[error("deserialize error: {0}")]
    DeserializeError(String),
    /// Offloading the GIL-holding deserialization to a blocking thread failed
    /// (e.g. the spawned blocking task panicked or was cancelled).
    #[error("gil offload error: {0}")]
    OffloadError(String),
}
// todos: // todos:
// - [ ] enable context cancellation // - [ ] enable context cancellation
// - this will likely require a change to the function signature python calling arguments // - this will likely require a change to the function signature python calling arguments
...@@ -109,6 +121,7 @@ where ...@@ -109,6 +121,7 @@ where
async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> { async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
// Create a context // Create a context
let (request, context) = request.transfer(()); let (request, context) = request.transfer(());
let ctx = context.context();
let id = context.id().to_string(); let id = context.id().to_string();
log::trace!("processing request: {}", id); log::trace!("processing request: {}", id);
...@@ -117,7 +130,6 @@ where ...@@ -117,7 +130,6 @@ where
// Create a channel to communicate between the Python thread and the Rust async context // Create a channel to communicate between the Python thread and the Rust async context
let (tx, rx) = mpsc::channel::<Annotated<Resp>>(128); let (tx, rx) = mpsc::channel::<Annotated<Resp>>(128);
let tx_error = tx.clone();
let stream = Python::with_gil(|py| { let stream = Python::with_gil(|py| {
let py_request = pythonize(py, &request)?; let py_request = pythonize(py, &request)?;
...@@ -128,68 +140,84 @@ where ...@@ -128,68 +140,84 @@ where
let stream = Box::pin(stream); let stream = Box::pin(stream);
let process = |item: Result<Py<PyAny>, PyErr>| -> Result<Annotated<Resp>, Error> {
let item = item
.map_err(|err| error!("error processing python async generator stream: {}", err))?;
let response = Python::with_gil(|py| depythonize::<Resp>(&item.into_bound(py)))?;
let response = Annotated::from_data(response);
Ok(response)
};
// process the stream // process the stream
// any error thrown in the stream will be caught and complete the processing task // any error thrown in the stream will be caught and complete the processing task
// errors are captured by a task that is watching the processing task // errors are captured by a task that is watching the processing task
// the error will be emitted as an annotated error // the error will be emitted as an annotated error
let processor = tokio::spawn(async move { let request_id = id.clone();
log::trace!("processing stream from python async generator: {}", id);
tokio::spawn(async move {
log::debug!(
request_id,
"starting task to process python async generator stream"
);
let mut stream = stream; let mut stream = stream;
let mut count = 0;
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
// let mut done = false; count += 1;
let response = match process(item) { log::trace!(
request_id,
"processing the {}th item from python async generator",
count
);
let mut done = false;
let response = match process_item::<Resp>(item).await {
Ok(response) => response, Ok(response) => response,
Err(err) => { Err(e) => {
// done = true; done = true;
Annotated::from_error(err.to_string())
let msg = match &e {
ResponseProcessingError::DeserializeError(e) => {
// tell the python async generator to stop generating
// right now, this is impossible as we are not passing the context to the python async generator
// todo: add task-local context to the python async generator
// see: https://github.com/triton-inference-server/triton_distributed/issues/130
ctx.stop_generating();
let msg = format!("critical error: invalid response object from python async generator; application-logic-mismatch: {}", e);
log::error!(request_id, "{}", msg);
msg
}
ResponseProcessingError::PythonException(e) => {
let msg = format!("a python exception was caught while processing the async generator: {}", e);
log::warn!(request_id, "{}", msg);
msg
}
ResponseProcessingError::OffloadError(e) => {
let msg = format!("critical error: failed to offload the python async generator to a new thread: {}", e);
log::error!(request_id, "{}", msg);
msg
} }
}; };
if tx.send(response).await.is_err() { Annotated::from_error(msg)
log::error!("generator response channel was dropped: {}", id);
return Err(error!("generator response channel was dropped"));
} }
};
// if done { if tx.send(response).await.is_err() {
// break; log::trace!(
// } request_id,
"error forwarding annotated response to channel; channel is closed"
);
break;
} }
Result::<()>::Ok(()) if done {
}); log::debug!(
request_id,
tokio::spawn(async move { "early termination of python async generator stream task"
match processor.await {
Ok(Ok(_)) => {}
Ok(Err(err)) => {
log::error!("error processing python async generator: {}", err);
tx_error
.send(Annotated::from_error(err.to_string()))
.await
.unwrap();
}
Err(err) => {
log::error!(
"error on tokio task for processing python async generator stream: {}",
err
); );
tx_error break;
.send(Annotated::from_error(err.to_string()))
.await
.unwrap();
} }
} }
log::debug!(
request_id,
"finished processing python async generator stream"
);
}); });
let stream = ReceiverStream::new(rx); let stream = ReceiverStream::new(rx);
...@@ -197,3 +225,23 @@ where ...@@ -197,3 +225,23 @@ where
Ok(ResponseStream::new(Box::pin(stream), context.context())) Ok(ResponseStream::new(Box::pin(stream), context.context()))
} }
} }
/// Converts one item from the Python async generator into an `Annotated`
/// response.
///
/// The deserialization requires the GIL, so it is offloaded to the blocking
/// thread pool rather than holding the GIL on an async worker thread.
///
/// # Errors
///
/// - [`ResponseProcessingError::PythonException`] if the generator raised.
/// - [`ResponseProcessingError::OffloadError`] if the blocking task failed to join.
/// - [`ResponseProcessingError::DeserializeError`] if the object did not match `Resp`.
async fn process_item<Resp>(
    item: Result<Py<PyAny>, PyErr>,
) -> Result<Annotated<Resp>, ResponseProcessingError>
where
    Resp: Data + for<'de> Deserialize<'de>,
{
    // A failed item means the Python generator itself raised an exception.
    let py_obj = match item {
        Ok(obj) => obj,
        Err(e) => return Err(ResponseProcessingError::PythonException(e.to_string())),
    };

    // Acquire the GIL on the blocking pool and deserialize there.
    let join_result = tokio::task::spawn_blocking(move || {
        Python::with_gil(|py| depythonize::<Resp>(&py_obj.into_bound(py)))
    })
    .await;

    let deserialized = join_result
        .map_err(|e| ResponseProcessingError::OffloadError(e.to_string()))?
        .map_err(|e| ResponseProcessingError::DeserializeError(e.to_string()))?;

    Ok(Annotated::from_data(deserialized))
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
import string
import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
# Soak Test
#
# This was a failure case for the distributed runtime. If the Rust Tokio
# runtime is started with a small number of threads, contention between the
# GIL and the asyncio event loop can starve the ingress handler until it
# times out.
#
# There may still be some blocking operations in the ingress handler that
# could still eventually be a problem.
@triton_worker()
async def worker(runtime: DistributedRuntime):
    """
    Soak-test entry point: start a `backend` server in a background task,
    drive it with many concurrent client requests, then shut down.
    """
    # Random namespace so repeated/concurrent runs don't collide.
    ns = random_string()
    task = asyncio.create_task(server_init(runtime, ns))
    await client_init(runtime, ns)
    # client_init returning means all requests completed; shutting down the
    # runtime lets the server task exit so we can await it.
    runtime.shutdown()
    await task
async def client_init(runtime: DistributedRuntime, ns: str):
    """
    Instantiate a `backend` client and issue many concurrent `generate`
    requests to put sustained load on the ingress handler.

    Args:
        runtime: distributed runtime used to look up the endpoint.
        ns: namespace in which the `backend` component was registered.
    """
    # get endpoint
    endpoint = runtime.namespace(ns).component("backend").endpoint("generate")

    # create client
    client = await endpoint.client()

    # wait for an endpoint to be ready
    await client.wait_for_endpoints()

    # Issue many concurrent requests to put load on the server,
    # the task should issue the request and process the response
    tasks = []
    for i in range(10000):
        tasks.append(asyncio.create_task(do_one(client)))

    # return_exceptions=True prevents gather from re-raising the first
    # failure: every task runs to completion and every exception is
    # retrieved. Without it, the error-counting step below would be dead
    # code (gather would raise before reaching it) and unretrieved task
    # exceptions would produce warnings.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # ensure all tasks are done and without errors
    error_count = sum(1 for r in results if isinstance(r, BaseException))
    assert error_count == 0, f"expected 0 errors, got {error_count}"
async def do_one(client):
    """Issue a single `generate` request and drain the response stream."""
    response = await client.generate("hello world")
    async for _chunk in response:
        pass
async def server_init(runtime: DistributedRuntime, ns: str):
    """
    Instantiate a `backend` component and serve the `generate` endpoint

    A `Component` can serve multiple endpoints
    """
    component = runtime.namespace(ns).component("backend")
    await component.create_service()

    endpoint = component.endpoint("generate")
    print("Started server instance")
    # Serves until the runtime is shut down by the worker.
    await endpoint.serve_endpoint(RequestHandler().generate)
class RequestHandler:
    """
    Handler for the `generate` endpoint: streams the request back to the
    caller one character at a time.
    """

    async def generate(self, request):
        # Yield each character with a short delay between items to
        # simulate incremental generation and keep requests in flight.
        for ch in request:
            await asyncio.sleep(0.1)
            yield ch
def random_string(length=10):
    """Return a random alphanumeric string of `length` characters (a-z, A-Z, 0-9)."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))
if __name__ == "__main__":
    # Install uvloop as the asyncio event loop implementation, then run the
    # soak-test worker to completion.
    uvloop.install()
    asyncio.run(worker())
...@@ -67,7 +67,7 @@ pub struct RuntimeConfig { ...@@ -67,7 +67,7 @@ pub struct RuntimeConfig {
/// Maximum number of blocking threads /// Maximum number of blocking threads
/// Blocking threads are used for blocking operations, this value must be greater than 0. /// Blocking threads are used for blocking operations, this value must be greater than 0.
#[validate(range(min = 1))] #[validate(range(min = 1))]
#[builder(default = "16")] #[builder(default = "512")]
#[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))] #[builder_field_attr(serde(skip_serializing_if = "Option::is_none"))]
pub max_blocking_threads: usize, pub max_blocking_threads: usize,
} }
...@@ -117,7 +117,10 @@ impl RuntimeConfig { ...@@ -117,7 +117,10 @@ impl RuntimeConfig {
impl Default for RuntimeConfig { impl Default for RuntimeConfig {
fn default() -> Self { fn default() -> Self {
Self::single_threaded() Self {
max_worker_threads: 16,
max_blocking_threads: 16,
}
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "integration")]
mod integration {
    pub const DEFAULT_NAMESPACE: &str = "triton-init";

    use futures::StreamExt;
    use std::sync::Arc;

    use triton_distributed::{
        pipeline::{
            async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
            ResponseStream, SingleIn,
        },
        protocols::annotated::Annotated,
        DistributedRuntime, Result, Runtime, Worker,
    };

    /// Soak test: runs a `backend` server and a client in the same process
    /// and drives the server with many concurrent streaming requests.
    #[test]
    fn main() -> Result<()> {
        env_logger::init();
        let worker = Worker::from_settings()?;
        worker.execute(app)
    }

    /// Spawns the backend and client tasks; shuts the runtime down once the
    /// client has finished so the server task can exit.
    async fn app(runtime: Runtime) -> Result<()> {
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
        let server = tokio::spawn(backend(distributed.clone()));
        let client = tokio::spawn(client(distributed.clone()));
        client.await??;
        distributed.shutdown();
        server.await??;
        Ok(())
    }

    /// Engine that echoes a request back one character at a time.
    struct RequestHandler {}

    impl RequestHandler {
        fn new() -> Arc<Self> {
            Arc::new(Self {})
        }
    }

    #[async_trait]
    impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for RequestHandler {
        async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
            let (data, ctx) = input.into_parts();

            let chars = data
                .chars()
                .map(|c| Annotated::from_data(c.to_string()))
                .collect::<Vec<_>>();

            // Stream each character with a small delay to keep many
            // responses in flight at once.
            let stream = async_stream::stream! {
                for c in chars {
                    yield c;
                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
                }
            };

            Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
        }
    }

    async fn backend(runtime: DistributedRuntime) -> Result<()> {
        // attach an ingress to an engine
        let ingress = Ingress::for_engine(RequestHandler::new())?;

        // // make the ingress discoverable via a component service
        // // we must first create a service, then we can attach one or more endpoints
        runtime
            .namespace(DEFAULT_NAMESPACE)?
            .component("backend")?
            .service_builder()
            .create()
            .await?
            .endpoint("generate")
            .endpoint_builder()
            .handler(ingress)
            .start()
            .await
    }

    async fn client(runtime: DistributedRuntime) -> Result<()> {
        let client = runtime
            .namespace(DEFAULT_NAMESPACE)?
            .component("backend")?
            .endpoint("generate")
            .client::<String, Annotated<String>>()
            .await?;

        client.wait_for_endpoints().await?;

        let client = Arc::new(client);

        // spawn 20000 tasks to put load on the server
        let mut tasks = Vec::new();
        for _ in 0..20000 {
            let client = client.clone();
            tasks.push(tokio::spawn(async move {
                // issue a request to a random endpoint instance and drain
                // the streaming response
                let mut stream = client.random("hello world".to_string().into()).await?;
                while let Some(_resp) = stream.next().await {}
                Ok::<(), Error>(())
            }));
        }

        // propagate both join errors and request errors
        for task in tasks.into_iter() {
            task.await??;
        }

        Ok(())
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment