Unverified commit d5f425ab, authored by Graham King, committed by GitHub
Browse files

chore(pipeline): Move migration outside of backend (#4823)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 7c15166d
......@@ -271,13 +271,13 @@ where
// Link with prefill chooser including backward edge for response flow
let engine = frontend
.link(preprocessor_op.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(backend.forward_edge())?
.link(prefill_op.forward_edge())?
.link(service_backend)?
.link(prefill_op.backward_edge())?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(migration.backward_edge())?
.link(preprocessor_op.backward_edge())?
.link(frontend)?;
......
......@@ -11,8 +11,8 @@ use async_nats::client::{
};
use crate::{
model_card::ModelDeploymentCard,
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
model_card::ModelDeploymentCard, preprocessor::BackendOutput,
protocols::common::llm_backend::PreprocessedRequest,
};
use dynamo_runtime::{
......@@ -44,16 +44,16 @@ impl Migration {
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
ManyOut<Annotated<BackendOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
ManyOut<Annotated<BackendOutput>>,
> for Migration
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let (preprocessed_request, context) = request.transfer(());
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
......@@ -73,8 +73,8 @@ impl
struct RetryManager {
context: Arc<dyn AsyncEngineContext>,
request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
retries_left: u32,
}
......@@ -82,7 +82,7 @@ impl RetryManager {
pub async fn build(
context: Arc<dyn AsyncEngineContext>,
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
retries_left: u32,
) -> Result<Self> {
let mut slf = Self {
......@@ -96,7 +96,7 @@ impl RetryManager {
Ok(slf)
}
pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
pub async fn next(&mut self) -> Option<Annotated<BackendOutput>> {
loop {
let response_stream = match self.next_stream.as_mut() {
Some(stream) => stream,
......@@ -128,7 +128,7 @@ impl RetryManager {
}
async fn new_stream(&mut self) -> Result<()> {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
let mut response_stream: Option<Result<ManyOut<Annotated<BackendOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context.id().to_string());
......@@ -162,7 +162,7 @@ impl RetryManager {
}
}
fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
fn track_response(&mut self, response: &Annotated<BackendOutput>) {
if self.retries_left == 0 {
return;
}
......@@ -207,18 +207,17 @@ mod tests {
}
// Helper to create mock LLM engine output
fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> {
Annotated::from_data(LLMEngineOutput {
fn create_mock_output(token_id: u32) -> Annotated<BackendOutput> {
Annotated::from_data(BackendOutput {
token_ids: vec![token_id],
tokens: None,
text: Some(format!("token_{}", token_id)),
tokens: vec![],
text: Some(format!("token_{token_id}")),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
completion_usage: None,
})
}
......@@ -267,16 +266,13 @@ mod tests {
#[async_trait]
impl
AsyncEngine<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
anyhow::Error,
> for MockEngine
AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, anyhow::Error>
for MockEngine
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, context) = request.transfer(());
......@@ -457,7 +453,7 @@ mod tests {
&self,
start: usize,
end: usize,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
......@@ -494,7 +490,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......@@ -533,7 +529,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......@@ -573,7 +569,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......@@ -613,7 +609,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
......@@ -641,7 +637,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......@@ -690,7 +686,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......@@ -739,7 +735,7 @@ mod tests {
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment