Commit fdcf611f, authored by Jacky, committed by GitHub

chore: Add Request Migration docs and minor enhancements (#2038)

parent bbe8dbb2
...@@ -4,3 +4,18 @@ Usage: ...@@ -4,3 +4,18 @@ Usage:
- `pip install -r requirements.txt` # Need a recent pip, `uv pip` might be too old. - `pip install -r requirements.txt` # Need a recent pip, `uv pip` might be too old.
- `python -m dynamo.llama_cpp --model-path /data/models/Qwen3-0.6B-Q8_0.gguf [args]` - `python -m dynamo.llama_cpp --model-path /data/models/Qwen3-0.6B-Q8_0.gguf [args]`
## Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the Frontend and the Backend.
The Frontend will automatically track which Backends are having connectivity issues with it and avoid routing new requests to the Backends with known connectivity issues.
For ongoing requests, the `--migration-limit` flag, set on the Backend, tells the Frontend how many times a request may be migrated to another Backend if connectivity to the current Backend is lost.
For example,
```bash
python3 -m dynamo.llama_cpp ... --migration-limit=3
```
indicates that a request to this model may be migrated to another Backend up to 3 times before it fails, should the Frontend detect a connectivity issue with the current Backend.
The migrated request will continue responding to the original request, allowing for a seamless transition between Backends, and a reduced overall request failure rate at the Frontend for enhanced user experience.
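The behavior described above can be sketched in a few lines of Python. This is a simplified illustration under stated assumptions, not the Frontend's actual implementation: `BackendDisconnected`, `generate_with_migration`, and the two demo backends are hypothetical names, and a Backend is modeled as a plain generator that receives the tokens seen so far.

```python
from typing import Callable, Iterator, List

# Hypothetical model of a Backend: takes the tokens so far, yields new tokens.
Backend = Callable[[List[int]], Iterator[int]]


class BackendDisconnected(Exception):
    """Hypothetical error standing in for a lost connection to a Backend."""


def generate_with_migration(
    prompt: List[int], backends: List[Backend], migration_limit: int
) -> List[int]:
    """Stream tokens, moving to another Backend on disconnect.

    At most `migration_limit` migrations are attempted; after that the
    original error is propagated to the caller.
    """
    tokens = list(prompt)
    migrations = 0
    while True:
        backend = backends[migrations % len(backends)]
        try:
            # The next Backend resumes from the prompt plus every token
            # already produced, so the response continues where it stopped.
            for tok in backend(list(tokens)):
                tokens.append(tok)
            return tokens[len(prompt):]
        except BackendDisconnected:
            if migrations >= migration_limit:
                raise
            migrations += 1


def flaky(context: List[int]) -> Iterator[int]:
    """Demo Backend that emits two tokens and then loses connectivity."""
    yield len(context)
    yield len(context) + 1
    raise BackendDisconnected("connection lost")


def healthy(context: List[int]) -> Iterator[int]:
    """Demo Backend that finishes the response (total length 6)."""
    for i in range(len(context), 6):
        yield i
```

With `migration_limit=1`, the call `generate_with_migration([0, 1, 2], [flaky, healthy], 1)` starts on `flaky`, migrates once, and completes on `healthy`. With `migration_limit=0` the same call raises `BackendDisconnected`.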
...@@ -29,6 +29,7 @@ class Config: ...@@ -29,6 +29,7 @@ class Config:
model_path: str model_path: str
model_name: Optional[str] model_name: Optional[str]
context_length: int context_length: int
migration_limit: int
@dynamo_worker(static=False) @dynamo_worker(static=False)
...@@ -40,7 +41,13 @@ async def worker(runtime: DistributedRuntime): ...@@ -40,7 +41,13 @@ async def worker(runtime: DistributedRuntime):
model_type = ModelType.Chat # llama.cpp does the pre-processing model_type = ModelType.Chat # llama.cpp does the pre-processing
endpoint = component.endpoint(config.endpoint) endpoint = component.endpoint(config.endpoint)
await register_llm(model_type, endpoint, config.model_path, config.model_name) await register_llm(
model_type,
endpoint,
config.model_path,
config.model_name,
migration_limit=config.migration_limit,
)
# Initialize the engine # Initialize the engine
# For more parameters see: # For more parameters see:
...@@ -100,6 +107,12 @@ def cmd_line_args(): ...@@ -100,6 +107,12 @@ def cmd_line_args():
default=None, default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.", help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
args = parser.parse_args() args = parser.parse_args()
config = Config() config = Config()
...@@ -124,6 +137,7 @@ def cmd_line_args(): ...@@ -124,6 +137,7 @@ def cmd_line_args():
config.component = parsed_component_name config.component = parsed_component_name
config.endpoint = parsed_endpoint_name config.endpoint = parsed_endpoint_name
config.context_length = args.context_length config.context_length = args.context_length
config.migration_limit = args.migration_limit
return config return config
......
...@@ -139,6 +139,22 @@ cd $DYNAMO_ROOT/components/backends/sglang ...@@ -139,6 +139,22 @@ cd $DYNAMO_ROOT/components/backends/sglang
./launch/disagg_dp_attn.sh ./launch/disagg_dp_attn.sh
``` ```
## Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the Frontend and the Backend.
The Frontend will automatically track which Backends are having connectivity issues with it and avoid routing new requests to the Backends with known connectivity issues.
For ongoing requests, the `--migration-limit` flag, set on the Backend, tells the Frontend how many times a request may be migrated to another Backend if connectivity to the current Backend is lost.
For example,
```bash
python3 -m dynamo.sglang ... --migration-limit=3
```
indicates that a request to this model may be migrated to another Backend up to 3 times before it fails, should the Frontend detect a connectivity issue with the current Backend.
The migrated request will continue responding to the original request, allowing for a seamless transition between Backends, and a reduced overall request failure rate at the Frontend for enhanced user experience.
## Advanced Examples ## Advanced Examples
Below we provide a selected list of advanced examples. Please open up an issue if you'd like to see a specific example! Below we provide a selected list of advanced examples. Please open up an issue if you'd like to see a specific example!
......
...@@ -311,11 +311,23 @@ async def worker(runtime: DistributedRuntime): ...@@ -311,11 +311,23 @@ async def worker(runtime: DistributedRuntime):
     logging.info("Signal handlers set up for graceful shutdown")

-    server_args = parse_sglang_args_inc(sys.argv[1:])
-    await init(runtime, server_args)
+    # TODO: Better handle non-sglang args
+    sys_argv = sys.argv[1:]
+    migration_limit = 0
+    try:
+        idx = sys_argv.index("--migration-limit")
+        migration_limit = int(sys_argv[idx + 1])
+        del sys_argv[idx : idx + 2]  # Remove the args from sys_argv
+    except Exception:
+        pass
+    server_args = parse_sglang_args_inc(sys_argv)
+    await init(runtime, server_args, migration_limit)


-async def init(runtime: DistributedRuntime, server_args: ServerArgs):
+async def init(
+    runtime: DistributedRuntime, server_args: ServerArgs, migration_limit: int
+):
     """Initialize worker (either prefill or aggregated)"""
     engine = sgl.Engine(server_args=server_args)
...@@ -330,6 +342,7 @@ async def init(runtime: DistributedRuntime, server_args: ServerArgs): ...@@ -330,6 +342,7 @@ async def init(runtime: DistributedRuntime, server_args: ServerArgs):
server_args.model_path, server_args.model_path,
server_args.served_model_name, server_args.served_model_name,
kv_cache_block_size=server_args.page_size, kv_cache_block_size=server_args.page_size,
migration_limit=migration_limit,
) )
if server_args.disaggregation_mode != "null": if server_args.disaggregation_mode != "null":
......
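A note on the SGLang hunk above: `sys_argv.index("--migration-limit")` only matches the space-separated form (`--migration-limit 3`). The README example in this commit uses `--migration-limit=3`, which the lookup would not find, so the limit would stay at 0 and the flag would be passed through to the SGLang parser. One possible way to accept both forms, sketched here with the standard library only, is `argparse.ArgumentParser.parse_known_args`. The helper name `split_migration_limit` is hypothetical and not part of this commit.

```python
import argparse
from typing import List, Tuple


def split_migration_limit(argv: List[str]) -> Tuple[int, List[str]]:
    """Extract --migration-limit and return it with the remaining args.

    Accepts both `--migration-limit 3` and `--migration-limit=3`.
    Unrecognized arguments are returned untouched, in order, so they can
    be handed to another parser afterwards.
    """
    # allow_abbrev=False keeps prefixes such as `--mig` from being
    # swallowed here instead of reaching the downstream parser.
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("--migration-limit", type=int, default=0)
    known, remaining = parser.parse_known_args(argv)
    return known.migration_limit, remaining
```

Under this sketch, the worker would call `split_migration_limit(sys.argv[1:])` and pass the returned list to `parse_sglang_args_inc`. Unlike the bare `except Exception: pass`, a malformed value such as `--migration-limit=abc` makes argparse exit with an error message rather than silently falling back to 0.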
...@@ -205,6 +205,22 @@ DISAGGREGATION_STRATEGY="prefill_first" ./launch/disagg.sh ...@@ -205,6 +205,22 @@ DISAGGREGATION_STRATEGY="prefill_first" ./launch/disagg.sh
Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disaggregated serving: UCX (default) and NIXL (experimental). For detailed information and configuration instructions for each method, see the [KV cache transfer guide](./kv-cache-tranfer.md). Dynamo with TensorRT-LLM supports two methods for transferring KV cache in disaggregated serving: UCX (default) and NIXL (experimental). For detailed information and configuration instructions for each method, see the [KV cache transfer guide](./kv-cache-tranfer.md).
## Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the Frontend and the Backend.
The Frontend will automatically track which Backends are having connectivity issues with it and avoid routing new requests to the Backends with known connectivity issues.
For ongoing requests, the `--migration-limit` flag, set on the Backend, tells the Frontend how many times a request may be migrated to another Backend if connectivity to the current Backend is lost.
For example,
```bash
python3 -m dynamo.trtllm ... --migration-limit=3
```
indicates that a request to this model may be migrated to another Backend up to 3 times before it fails, should the Frontend detect a connectivity issue with the current Backend.
The migrated request will continue responding to the original request, allowing for a seamless transition between Backends, and a reduced overall request failure rate at the Frontend for enhanced user experience.
## More Example Architectures ## More Example Architectures
- [Llama 4 Maverick Instruct + Eagle Speculative Decoding](./llama4_plus_eagle.md) - [Llama 4 Maverick Instruct + Eagle Speculative Decoding](./llama4_plus_eagle.md)
...@@ -137,6 +137,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,6 +137,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path, config.model_path,
config.served_model_name, config.served_model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
) )
# publisher will be set later if publishing is enabled. # publisher will be set later if publishing is enabled.
......
...@@ -28,6 +28,7 @@ class Config: ...@@ -28,6 +28,7 @@ class Config:
self.served_model_name: Optional[str] = None self.served_model_name: Optional[str] = None
self.tensor_parallel_size: int = 1 self.tensor_parallel_size: int = 1
self.kv_block_size: int = 32 self.kv_block_size: int = 32
self.migration_limit: int = 0
self.extra_engine_args: str = "" self.extra_engine_args: str = ""
self.publish_events_and_metrics: bool = False self.publish_events_and_metrics: bool = False
self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
...@@ -46,6 +47,7 @@ class Config: ...@@ -46,6 +47,7 @@ class Config:
f"tensor_parallel_size={self.tensor_parallel_size}, " f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, " f"kv_block_size={self.kv_block_size}, "
f"extra_engine_args={self.extra_engine_args}, " f"extra_engine_args={self.extra_engine_args}, "
f"migration_limit={self.migration_limit}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, " f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, " f"disaggregation_mode={self.disaggregation_mode}, "
f"disaggregation_strategy={self.disaggregation_strategy}, " f"disaggregation_strategy={self.disaggregation_strategy}, "
...@@ -113,6 +115,12 @@ def cmd_line_args(): ...@@ -113,6 +115,12 @@ def cmd_line_args():
parser.add_argument( parser.add_argument(
"--kv-block-size", type=int, default=32, help="Size of a KV cache block." "--kv-block-size", type=int, default=32, help="Size of a KV cache block."
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
...@@ -188,6 +196,7 @@ def cmd_line_args(): ...@@ -188,6 +196,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics config.publish_events_and_metrics = args.publish_events_and_metrics
......
...@@ -186,3 +186,19 @@ vLLM workers are configured through command-line arguments. Key parameters inclu ...@@ -186,3 +186,19 @@ vLLM workers are configured through command-line arguments. Key parameters inclu
See `args.py` for the full list of configuration options and their defaults. See `args.py` for the full list of configuration options and their defaults.
The [documentation](https://docs.vllm.ai/en/v0.9.2/configuration/serve_args.html?h=serve+arg) for the vLLM CLI args points to running 'vllm serve --help' to see what CLI args can be added. We use the same argument parser as vLLM. The [documentation](https://docs.vllm.ai/en/v0.9.2/configuration/serve_args.html?h=serve+arg) for the vLLM CLI args points to running 'vllm serve --help' to see what CLI args can be added. We use the same argument parser as vLLM.
## Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the Frontend and the Backend.
The Frontend will automatically track which Backends are having connectivity issues with it and avoid routing new requests to the Backends with known connectivity issues.
For ongoing requests, the `--migration-limit` flag, set on the Backend, tells the Frontend how many times a request may be migrated to another Backend if connectivity to the current Backend is lost.
For example,
```bash
python3 -m dynamo.vllm ... --migration-limit=3
```
indicates that a request to this model may be migrated to another Backend up to 3 times before it fails, should the Frontend detect a connectivity issue with the current Backend.
The migrated request will continue responding to the original request, allowing for a seamless transition between Backends, and a reduced overall request failure rate at the Frontend for enhanced user experience.
...@@ -31,6 +31,7 @@ class Config: ...@@ -31,6 +31,7 @@ class Config:
component: str component: str
endpoint: str endpoint: str
is_prefill_worker: bool is_prefill_worker: bool
migration_limit: int = 0
kv_port: Optional[int] = None kv_port: Optional[int] = None
side_channel_port: Optional[int] = None side_channel_port: Optional[int] = None
...@@ -57,6 +58,12 @@ def parse_args() -> Config: ...@@ -57,6 +58,12 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate", help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -102,6 +109,7 @@ def parse_args() -> Config: ...@@ -102,6 +109,7 @@ def parse_args() -> Config:
config.endpoint = parsed_endpoint_name config.endpoint = parsed_endpoint_name
config.engine_args = engine_args config.engine_args = engine_args
config.is_prefill_worker = args.is_prefill_worker config.is_prefill_worker = args.is_prefill_worker
config.migration_limit = args.migration_limit
if config.engine_args.block_size is None: if config.engine_args.block_size is None:
config.engine_args.block_size = 16 config.engine_args.block_size = 16
......
...@@ -148,6 +148,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -148,6 +148,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model, config.model,
config.served_model_name, config.served_model_name,
kv_cache_block_size=config.engine_args.block_size, kv_cache_block_size=config.engine_args.block_size,
migration_limit=config.migration_limit,
) )
factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0) factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0)
......
...@@ -209,6 +209,22 @@ The KV-aware routing arguments: ...@@ -209,6 +209,22 @@ The KV-aware routing arguments:
- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events. - `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
### Request Migration
In a [Distributed System](#distributed-system), a request may fail due to connectivity issues between the HTTP Server and the Worker Engine.
The HTTP Server will automatically track which Worker Engines are having connectivity issues with it and avoid routing new requests to the Engines with known connectivity issues.
For ongoing requests, the `--migration-limit` flag, set on the Worker Engine, tells the HTTP Server how many times a request may be migrated to another Engine if connectivity to the current Engine is lost.
For example,
```bash
dynamo-run in=dyn://... out=vllm ... --migration-limit=3
```
indicates that a request to this model may be migrated to another Engine up to 3 times before it fails, should the HTTP Server detect a connectivity issue with the current Engine.
The migrated request will continue responding to the original request, allowing for a seamless transition between Engines, and a reduced overall request failure rate at the HTTP Server for enhanced user experience.
## Full usage details ## Full usage details
The `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features. The `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
...@@ -24,7 +24,7 @@ Example: ...@@ -24,7 +24,7 @@ Example:
- OR: ./dynamo-run /data/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf - OR: ./dynamo-run /data/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#; "#;
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--kv-cache-block-size=16] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)]"; const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--kv-cache-block-size=16] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--migration-limit=0] [--verbosity (-v|-vv)]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// Set log level based on verbosity flag // Set log level based on verbosity flag
......
...@@ -113,9 +113,9 @@ impl RetryManager { ...@@ -113,9 +113,9 @@ impl RetryManager {
if let Some(err) = response.err() { if let Some(err) = response.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG { if format!("{:?}", err) == STREAM_ERR_MSG {
tracing::info!("Stream disconnected... recreating stream..."); tracing::warn!("Stream disconnected... recreating stream...");
if let Err(err) = self.new_stream().await { if let Err(err) = self.new_stream().await {
tracing::info!("Cannot recreate stream: {:?}", err); tracing::warn!("Cannot recreate stream: {:?}", err);
} else { } else {
continue; continue;
} }
...@@ -138,7 +138,7 @@ impl RetryManager { ...@@ -138,7 +138,7 @@ impl RetryManager {
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() { if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() { if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) { if matches!(req_err.kind(), NatsNoResponders) {
tracing::info!("Creating new stream... retrying..."); tracing::warn!("Creating new stream... retrying...");
continue; continue;
} }
} }
...@@ -150,9 +150,9 @@ impl RetryManager { ...@@ -150,9 +150,9 @@ impl RetryManager {
self.next_stream = Some(next_stream); self.next_stream = Some(next_stream);
Ok(()) Ok(())
} }
Some(Err(err)) => Err(err), // should propagate streaming error if stream started Some(Err(err)) => Err(err), // should propagate original error if any
None => Err(Error::msg( None => Err(Error::msg(
"Retries exhausted - should propagate streaming error", "Migration limit exhausted", // should propagate original error if any
)), )),
} }
} }
...@@ -165,6 +165,10 @@ impl RetryManager { ...@@ -165,6 +165,10 @@ impl RetryManager {
Some(output) => output, Some(output) => output,
None => return, None => return,
}; };
if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
self.request.stop_conditions.max_tokens =
Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
}
for token_id in llm_engine_output.token_ids.iter() { for token_id in llm_engine_output.token_ids.iter() {
self.request.token_ids.push(*token_id); self.request.token_ids.push(*token_id);
} }
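The hunk above makes the retry manager shrink `max_tokens` by the number of tokens already received before it appends them to the request, so a migrated request only asks the next worker for the remaining budget. A minimal Python sketch of the same bookkeeping follows. The `Request` dataclass and `track_response` function are hypothetical stand-ins for `PreprocessedRequest` and the Rust method, not the actual API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    """Hypothetical stand-in for the preprocessed request."""
    token_ids: List[int]
    max_tokens: Optional[int] = None  # None means no explicit limit


def track_response(request: Request, new_token_ids: List[int]) -> None:
    """Fold newly received tokens into the request.

    Mirrors the diff: reduce the remaining token budget (never below
    zero, like saturating_sub), then extend the context so a migrated
    request resumes after the tokens already generated.
    """
    if request.max_tokens is not None:
        request.max_tokens = max(0, request.max_tokens - len(new_token_ids))
    request.token_ids.extend(new_token_ids)
```

For example, a request with `max_tokens=10` that has received 5 tokens before a disconnect is resent with `max_tokens=5`, which matches the assertion added to the mock engine in the tests below.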
...@@ -181,11 +185,14 @@ mod tests { ...@@ -181,11 +185,14 @@ mod tests {
use tokio::sync::mpsc; use tokio::sync::mpsc;
// Helper to create a mock preprocessed request // Helper to create a mock preprocessed request
fn create_mock_request() -> PreprocessedRequest { fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
PreprocessedRequest { PreprocessedRequest {
token_ids: vec![1, 2, 3], token_ids: vec![1, 2, 3],
batch_token_ids: None, batch_token_ids: None,
stop_conditions: StopConditions::default(), stop_conditions: StopConditions {
max_tokens: Some(max_tokens),
..Default::default()
},
sampling_options: SamplingOptions::default(), sampling_options: SamplingOptions::default(),
eos_token_ids: vec![], eos_token_ids: vec![],
mdc_sum: None, mdc_sum: None,
...@@ -264,9 +271,18 @@ mod tests { ...@@ -264,9 +271,18 @@ mod tests {
             .token_ids
             .len()
             .saturating_sub(initial_tokens);
-        let _responses_remaining = self
-            .num_responses
-            .saturating_sub(responses_already_generated);
+
+        // Assert that max_tokens reflects the expected remaining tokens
+        let expected_max_tokens =
+            self.num_responses
+                .saturating_sub(responses_already_generated) as u32;
+        assert_eq!(
+            preprocessed_request.stop_conditions.max_tokens,
+            Some(expected_max_tokens),
+            "max_tokens should be {} but got {:?}",
+            expected_max_tokens,
+            preprocessed_request.stop_conditions.max_tokens
+        );
match &self.behavior { match &self.behavior {
MockBehavior::Success => { MockBehavior::Success => {
...@@ -454,7 +470,7 @@ mod tests { ...@@ -454,7 +470,7 @@ mod tests {
/// Expected behavior: All 10 responses should be received successfully. /// Expected behavior: All 10 responses should be received successfully.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_no_migration() { async fn test_retry_manager_no_migration() {
let request = create_mock_request(); let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100)); let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
...@@ -485,7 +501,7 @@ mod tests { ...@@ -485,7 +501,7 @@ mod tests {
/// Expected behavior: All 10 responses should be received successfully after retry. /// Expected behavior: All 10 responses should be received successfully after retry.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_new_request_migration() { async fn test_retry_manager_new_request_migration() {
let request = create_mock_request(); let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100)); let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
...@@ -516,7 +532,7 @@ mod tests { ...@@ -516,7 +532,7 @@ mod tests {
/// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total. /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_ongoing_request_migration() { async fn test_retry_manager_ongoing_request_migration() {
let request = create_mock_request(); let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new( let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 5 }, MockBehavior::MidStreamFail { fail_after: 5 },
10, 10,
...@@ -552,7 +568,7 @@ mod tests { ...@@ -552,7 +568,7 @@ mod tests {
/// Expected behavior: Should receive an error after all retries are exhausted, with the original error. /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_new_request_migration_indefinite_failure() { async fn test_retry_manager_new_request_migration_indefinite_failure() {
let request = create_mock_request(); let request = create_mock_request(0);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100)); let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine; mock_engine;
...@@ -572,7 +588,7 @@ mod tests { ...@@ -572,7 +588,7 @@ mod tests {
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted. /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() { async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
let request = create_mock_request(); let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new( let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlways { fail_after: 3 }, MockBehavior::MidStreamFailAlways { fail_after: 3 },
10, 10,
...@@ -619,7 +635,7 @@ mod tests { ...@@ -619,7 +635,7 @@ mod tests {
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted. /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
#[tokio::test] #[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() { async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
let request = create_mock_request(); let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new( let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 }, MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
10, 10,
......