Unverified Commit e1af3af6 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove static mode (#4235)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent d9b674b8
...@@ -81,13 +81,8 @@ pub async fn run( ...@@ -81,13 +81,8 @@ pub async fn run(
let dst_config = DistributedConfig { let dst_config = DistributedConfig {
store_backend: selected_store, store_backend: selected_store,
nats_config: nats::ClientOptions::default(), nats_config: nats::ClientOptions::default(),
is_static: flags.static_worker,
}; };
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?; let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
if let Some(Output::Static(path)) = &out_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
}
let local_model = builder.build().await?; let local_model = builder.build().await?;
// //
...@@ -98,7 +93,7 @@ pub async fn run( ...@@ -98,7 +93,7 @@ pub async fn run(
print_cuda(&out_opt); print_cuda(&out_opt);
// Now that we know the output we're targeting, check if we expect it to work // Now that we know the output we're targeting, check if we expect it to work
flags.validate(&in_opt, &out_opt)?; flags.validate(&out_opt)?;
// Make an engine from the local_model, flags and output. // Make an engine from the local_model, flags and output.
let engine_config = engine_for( let engine_config = engine_for(
...@@ -128,20 +123,14 @@ async fn engine_for( ...@@ -128,20 +123,14 @@ async fn engine_for(
// Auto-discover backends // Auto-discover backends
Ok(EngineConfig::Dynamic(Box::new(local_model))) Ok(EngineConfig::Dynamic(Box::new(local_model)))
} }
Output::Static(_) => {
// A single static backend, no etcd
Ok(EngineConfig::StaticRemote(Box::new(local_model)))
}
Output::Echo => Ok(EngineConfig::StaticFull { Output::Echo => Ok(EngineConfig::StaticFull {
model: Box::new(local_model), model: Box::new(local_model),
engine: dynamo_llm::engines::make_echo_engine(), engine: dynamo_llm::engines::make_echo_engine(),
is_static: flags.static_worker,
}), }),
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => Ok(EngineConfig::StaticFull { Output::MistralRs => Ok(EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(&local_model).await?, engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model), model: Box::new(local_model),
is_static: flags.static_worker,
}), }),
Output::Mocker => { Output::Mocker => {
let args = flags.mocker_config(); let args = flags.mocker_config();
...@@ -153,7 +142,6 @@ async fn engine_for( ...@@ -153,7 +142,6 @@ async fn engine_for(
Ok(EngineConfig::StaticCore { Ok(EngineConfig::StaticCore {
engine, engine,
model: Box::new(local_model), model: Box::new(local_model),
is_static: flags.static_worker,
is_prefill: false, is_prefill: false,
}) })
} }
......
...@@ -25,7 +25,7 @@ Example: ...@@ -25,7 +25,7 @@ Example:
See `docs/guides/dynamo_run.md` in the repo for full details. See `docs/guides/dynamo_run.md` in the repo for full details.
"#; "#;
const USAGE: &str = "USAGE: dynamo-run in=[http|grpc|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|auto|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--context-length=N] [--kv-cache-block-size=16] [--extra-engine-args=args.json] [--static-worker] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--router-temperature=0.0] [--use-kv-events] [--max-num-batched-tokens=1.0] [--migration-limit=0] [--verbosity (-v|-vv)]"; const USAGE: &str = "USAGE: dynamo-run in=[http|grpc|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|auto|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--context-length=N] [--kv-cache-block-size=16] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--router-temperature=0.0] [--use-kv-events] [--max-num-batched-tokens=1.0] [--migration-limit=0] [--verbosity (-v|-vv)]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set log level based on verbosity flag // Set log level based on verbosity flag
...@@ -138,5 +138,5 @@ fn is_in_dynamic(in_opt: &Input) -> bool { ...@@ -138,5 +138,5 @@ fn is_in_dynamic(in_opt: &Input) -> bool {
} }
fn is_out_dynamic(out_opt: &Option<Output>) -> bool { fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
matches!(out_opt, Some(Output::Auto) | Some(Output::Static(_))) matches!(out_opt, Some(Output::Auto))
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::protocols::ENDPOINT_SCHEME;
use std::fmt; use std::fmt;
pub enum Output { pub enum Output {
...@@ -11,13 +10,6 @@ pub enum Output { ...@@ -11,13 +10,6 @@ pub enum Output {
/// Listen for models on nats/etcd, add/remove dynamically /// Listen for models on nats/etcd, add/remove dynamically
Auto, Auto,
/// Static remote: The dyn://namespace.component.endpoint name of a remote worker we expect to
/// exists. THIS DISABLES AUTO-DISCOVERY. Only this endpoint will be connected.
/// `--model-name and `--model-path` must also be set.
///
/// A static remote setup avoids having to run etcd.
Static(String),
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
MistralRs, MistralRs,
...@@ -37,11 +29,6 @@ impl TryFrom<&str> for Output { ...@@ -37,11 +29,6 @@ impl TryFrom<&str> for Output {
"dyn" | "auto" => Ok(Output::Auto), "dyn" | "auto" => Ok(Output::Auto),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Output::Static(path.to_string()))
}
e => Err(anyhow::anyhow!("Invalid out= option '{e}'")), e => Err(anyhow::anyhow!("Invalid out= option '{e}'")),
} }
} }
...@@ -57,7 +44,6 @@ impl fmt::Display for Output { ...@@ -57,7 +44,6 @@ impl fmt::Display for Output {
Output::Echo => "echo", Output::Echo => "echo",
Output::Auto => "auto", Output::Auto => "auto",
Output::Static(endpoint) => &format!("{ENDPOINT_SCHEME}{endpoint}"),
}; };
write!(f, "{s}") write!(f, "{s}")
} }
......
...@@ -943,7 +943,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -943,7 +943,7 @@ pub async fn create_worker_selection_pipeline_chat(
>, >,
> { > {
let runtime = Runtime::from_settings()?; let runtime = Runtime::from_settings()?;
let dst_config = DistributedConfig::from_settings(false); let dst_config = DistributedConfig::from_settings();
let drt_owned = DistributedRuntime::new(runtime, dst_config).await?; let drt_owned = DistributedRuntime::new(runtime, dst_config).await?;
let distributed_runtime: &'static DistributedRuntime = Box::leak(Box::new(drt_owned)); let distributed_runtime: &'static DistributedRuntime = Box::leak(Box::new(drt_owned));
......
...@@ -115,7 +115,7 @@ def parse_args(): ...@@ -115,7 +115,7 @@ def parse_args():
async def run(): async def run():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "etcd", False) runtime = DistributedRuntime(loop, "etcd")
args = parse_args() args = parse_args()
......
...@@ -20,7 +20,7 @@ import uvloop ...@@ -20,7 +20,7 @@ import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=False) @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo") await init(runtime, "dynamo")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo")
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
# create client
client = await endpoint.client()
# wait for an endpoint to be ready
await client.wait_for_instances()
# issue request
stream = await client.generate("hello world")
# process the stream
async for char in stream:
print(char)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
import string
import uvloop
from client import init as client_init
from server import init as server_init
from dynamo.runtime import DistributedRuntime, dynamo_worker
def random_string(length=10):
chars = string.ascii_letters + string.digits # a-z, A-Z, 0-9
return "".join(random.choices(chars, k=length))
@dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime):
ns = random_string()
task = asyncio.create_task(server_init(runtime, ns))
await asyncio.sleep(0.1) # let the server start
await client_init(runtime, ns)
runtime.shutdown()
await task
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
...@@ -21,7 +21,7 @@ class RequestHandler: ...@@ -21,7 +21,7 @@ class RequestHandler:
yield char yield char
@dynamo_worker(static=False) @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
......
...@@ -80,7 +80,7 @@ class RequestHandler: ...@@ -80,7 +80,7 @@ class RequestHandler:
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False) @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Static version of server_sglang.py - see there for most details.
#
# The key differences between this and `server_sglang.py` are:
# - We do not call register_llm to advertise ourself in etcd. There is no etcd.
# - The frontend must know up-front all the details for the model: name, pre-processor path, and type.
#
# Window 1: `python server_sglang_static.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate --model-name "Qwen/Qwen3-0.6B" --model-path <hf_path> --model-type Backend
import argparse
import asyncio
import os
import sys
import sglang
import uvloop
from sglang.srt.server_args import ServerArgs
from dynamo.runtime import DistributedRuntime, dynamo_worker
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
DEFAULT_TEMPERATURE = 0.7
class Config:
"""Command line parameters or defaults"""
namespace: str
component: str
endpoint: str
model: str
class RequestHandler:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine):
self.engine_client = engine
async def generate(self, request):
# print(f"Received request: {request}")
sampling_params = {
"temperature": request["sampling_options"]["temperature"]
or DEFAULT_TEMPERATURE,
# sglang defaults this to 128
"max_new_tokens": request["stop_conditions"]["max_tokens"],
}
num_output_tokens_so_far = 0
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"], sampling_params=sampling_params, stream=True
)
async for res in gen:
# res is a dict
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
# Don't forward the stop token
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
next_total_toks = num_output_tokens_so_far
else:
next_total_toks = len(res["output_ids"])
out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
yield out
num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
async def init(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
engine_args = ServerArgs(
model_path=config.model,
skip_tokenizer_init=True,
)
engine_client = sglang.Engine(server_args=engine_args)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
def cmd_line_args():
parser = argparse.ArgumentParser(
description="SGLang server integrated with Dynamo runtime."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--model",
type=str,
default=DEFAULT_MODEL,
help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
)
args = parser.parse_args()
config = Config()
config.model = args.model
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
print(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
return config
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
...@@ -93,7 +93,7 @@ class RequestHandler: ...@@ -93,7 +93,7 @@ class RequestHandler:
count = next_count count = next_count
@dynamo_worker(static=False) @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
# A very basic example of vllm worker handling pre-processed requests. # A very basic example of vllm worker handling pre-processed requests.
...@@ -103,7 +91,7 @@ class RequestHandler: ...@@ -103,7 +91,7 @@ class RequestHandler:
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
@dynamo_worker(static=False) @dynamo_worker()
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args()) await init(runtime, cmd_line_args())
......
...@@ -432,7 +432,7 @@ enum ModelInput { ...@@ -432,7 +432,7 @@ enum ModelInput {
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject, store_kv: String, is_static: bool) -> PyResult<Self> { fn new(event_loop: PyObject, store_kv: String) -> PyResult<Self> {
let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?; let selected_kv_store: KeyValueStoreSelect = store_kv.parse().map_err(to_pyerr)?;
// Try to get existing runtime first, create new Worker only if needed // Try to get existing runtime first, create new Worker only if needed
...@@ -463,22 +463,14 @@ impl DistributedRuntime { ...@@ -463,22 +463,14 @@ impl DistributedRuntime {
}); });
} }
let inner = let runtime_config = DistributedConfig {
if is_static { store_backend: selected_kv_store,
runtime.secondary().block_on( nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
rs::DistributedRuntime::from_settings_without_discovery(runtime), };
) let inner = runtime
} else { .secondary()
let config = DistributedConfig { .block_on(rs::DistributedRuntime::new(runtime, runtime_config))
store_backend: selected_kv_store, .map_err(to_pyerr)?;
is_static: false,
nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
};
runtime
.secondary()
.block_on(rs::DistributedRuntime::new(runtime, config))
};
let inner = inner.map_err(to_pyerr)?;
Ok(DistributedRuntime { inner, event_loop }) Ok(DistributedRuntime { inner, event_loop })
} }
...@@ -867,11 +859,7 @@ impl Client { ...@@ -867,11 +859,7 @@ impl Client {
annotated: Option<bool>, annotated: Option<bool>,
context: Option<context::Context>, context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
if self.router.client.is_static() { self.random(py, request, annotated, context)
self.r#static(py, request, annotated, context)
} else {
self.random(py, request, annotated, context)
}
} }
/// Send a request to the next endpoint in a round-robin fashion. /// Send a request to the next endpoint in a round-robin fashion.
...@@ -991,45 +979,6 @@ impl Client { ...@@ -991,45 +979,6 @@ impl Client {
}) })
}) })
} }
/// Directly send a request to a pre-defined static worker
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn r#static<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let request_ctx = create_request_context(request, &context);
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
// Always instrument with appropriate span (none if no trace context)
let span = get_span_for_context(&context, "static");
client
.r#static(request_ctx)
.instrument(span)
.await
.map_err(to_pyerr)?
}
_ => client.r#static(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
annotated,
})
})
}
} }
async fn process_stream( async fn process_stream(
......
...@@ -24,7 +24,6 @@ pub enum EngineType { ...@@ -24,7 +24,6 @@ pub enum EngineType {
Echo = 1, Echo = 1,
Dynamic = 2, Dynamic = 2,
Mocker = 3, Mocker = 3,
Static = 4,
} }
#[pyclass] #[pyclass]
...@@ -246,11 +245,9 @@ async fn select_engine( ...@@ -246,11 +245,9 @@ async fn select_engine(
RsEngineConfig::StaticFull { RsEngineConfig::StaticFull {
model: Box::new(local_model), model: Box::new(local_model),
engine: dynamo_llm::engines::make_echo_engine(), engine: dynamo_llm::engines::make_echo_engine(),
is_static: false,
} }
} }
EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)), EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
EngineType::Mocker => { EngineType::Mocker => {
let mocker_args = if let Some(extra_args_path) = args.extra_engine_args { let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| { MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
...@@ -279,7 +276,6 @@ async fn select_engine( ...@@ -279,7 +276,6 @@ async fn select_engine(
RsEngineConfig::StaticCore { RsEngineConfig::StaticCore {
engine, engine,
model: Box::new(local_model), model: Box::new(local_model),
is_static: false,
is_prefill: args.is_prefill, is_prefill: args.is_prefill,
} }
} }
......
...@@ -20,12 +20,12 @@ from dynamo._core import Namespace as Namespace ...@@ -20,12 +20,12 @@ from dynamo._core import Namespace as Namespace
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
def dynamo_worker(static=False): def dynamo_worker():
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "etcd", static) runtime = DistributedRuntime(loop, "etcd")
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio import asyncio
...@@ -32,6 +20,8 @@ class MockServer: ...@@ -32,6 +20,8 @@ class MockServer:
self.context_is_killed = False self.context_is_killed = False
async def generate(self, request, context): async def generate(self, request, context):
print("################## generate called ######################")
self.context_is_stopped = False self.context_is_stopped = False
self.context_is_killed = False self.context_is_killed = False
...@@ -127,7 +117,7 @@ class MockServer: ...@@ -127,7 +117,7 @@ class MockServer:
@pytest.fixture @pytest.fixture
def namespace(): def namespace():
"""Namespace for this test file""" """Namespace for this test file"""
return "cancellation_unit_test" return "cancellation-unit-test"
@pytest.fixture @pytest.fixture
...@@ -176,7 +166,7 @@ async def client(runtime, namespace): ...@@ -176,7 +166,7 @@ async def client(runtime, namespace):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_context_cancel(server, client): async def test_client_context_cancel(temp_file_store, server, client):
_, handler = server _, handler = server
context = Context() context = Context()
stream = await client.generate("_generate_until_context_cancelled", context=context) stream = await client.generate("_generate_until_context_cancelled", context=context)
...@@ -209,7 +199,7 @@ async def test_client_context_cancel(server, client): ...@@ -209,7 +199,7 @@ async def test_client_context_cancel(server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_loop_break(server, client): async def test_client_loop_break(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_until_context_cancelled") stream = await client.generate("_generate_until_context_cancelled")
...@@ -241,7 +231,7 @@ async def test_client_loop_break(server, client): ...@@ -241,7 +231,7 @@ async def test_client_loop_break(server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_context_cancel(server, client): async def test_server_context_cancel(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_and_cancel_context") stream = await client.generate("_generate_and_cancel_context")
...@@ -265,7 +255,7 @@ async def test_server_context_cancel(server, client): ...@@ -265,7 +255,7 @@ async def test_server_context_cancel(server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_raise_cancelled(server, client): async def test_server_raise_cancelled(temp_file_store, server, client):
_, handler = server _, handler = server
stream = await client.generate("_generate_and_raise_cancelled") stream = await client.generate("_generate_and_raise_cancelled")
...@@ -293,7 +283,7 @@ async def test_server_raise_cancelled(server, client): ...@@ -293,7 +283,7 @@ async def test_server_raise_cancelled(server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_context_already_cancelled(server, client): async def test_client_context_already_cancelled(temp_file_store, server, client):
_, handler = server _, handler = server
context = Context() context = Context()
context.stop_generating() context.stop_generating()
...@@ -315,7 +305,9 @@ async def test_client_context_already_cancelled(server, client): ...@@ -315,7 +305,9 @@ async def test_client_context_already_cancelled(server, client):
@pytest.mark.forked @pytest.mark.forked
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_context_cancel_before_await_request(server, client): async def test_client_context_cancel_before_await_request(
temp_file_store, server, client
):
_, handler = server _, handler = server
context = Context() context = Context()
request = client.generate("_generate_until_context_cancelled", context=context) request = client.generate("_generate_until_context_cancelled", context=context)
......
...@@ -97,9 +97,12 @@ def stop_process(name, process): ...@@ -97,9 +97,12 @@ def stop_process(name, process):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_direct_connection_cancellation(example_dir, server_process): async def test_direct_connection_cancellation(
temp_file_store, example_dir, server_process
):
"""Test cancellation with direct client-server connection""" """Test cancellation with direct client-server connection"""
# Run the client (direct connection) # Run the client (direct connection)
print(f"Key-value store dir: {temp_file_store}")
client_output = run_client(example_dir, use_middle=False) client_output = run_client(example_dir, use_middle=False)
# Wait for server to print cancellation message # Wait for server to print cancellation message
...@@ -119,10 +122,11 @@ async def test_direct_connection_cancellation(example_dir, server_process): ...@@ -119,10 +122,11 @@ async def test_direct_connection_cancellation(example_dir, server_process):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_middle_server_cancellation( async def test_middle_server_cancellation(
example_dir, server_process, middle_server_process temp_file_store, example_dir, server_process, middle_server_process
): ):
"""Test cancellation with middle server proxy""" """Test cancellation with middle server proxy"""
# Run the client (through middle server) # Run the client (through middle server)
print(f"Key-value store dir: {temp_file_store}")
client_output = run_client(example_dir, use_middle=True) client_output = run_client(example_dir, use_middle=True)
# Wait for server to print cancellation message # Wait for server to print cancellation message
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Test configuration and fixtures for Dynamo Python bindings tests. Test configuration and fixtures for Dynamo Python bindings tests.
...@@ -403,6 +391,17 @@ def nats_and_etcd(): ...@@ -403,6 +391,17 @@ def nats_and_etcd():
print(f"Error removing ETCD data dir: {e}") print(f"Error removing ETCD data dir: {e}")
@pytest.fixture(scope="function")
def temp_file_store():
"""
A temporary directory to use as the key-value store. Cleaned up on test exit.
Local to the unit test using it.
"""
with tempfile.TemporaryDirectory() as tmpdir:
os.environ["DYN_FILE_KV"] = tmpdir
yield tmpdir
@pytest.fixture(scope="function", autouse=False) @pytest.fixture(scope="function", autouse=False)
async def runtime(request): async def runtime(request):
""" """
...@@ -436,6 +435,6 @@ This is required because DistributedRuntime is a process-level singleton. ...@@ -436,6 +435,6 @@ This is required because DistributedRuntime is a process-level singleton.
) )
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "mem", True) runtime = DistributedRuntime(loop, "file")
yield runtime yield runtime
runtime.shutdown() runtime.shutdown()
...@@ -34,7 +34,7 @@ async def distributed_runtime(): ...@@ -34,7 +34,7 @@ async def distributed_runtime():
Each test gets its own runtime in a forked process to avoid singleton conflicts. Each test gets its own runtime in a forked process to avoid singleton conflicts.
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, "etcd", False) runtime = DistributedRuntime(loop, "etcd")
yield runtime yield runtime
runtime.shutdown() runtime.shutdown()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment