Unverified Commit d5f425ab authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(pipeline): Move migration outside of backend (#4823)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 7c15166d
...@@ -271,13 +271,13 @@ where ...@@ -271,13 +271,13 @@ where
// Link with prefill chooser including backward edge for response flow // Link with prefill chooser including backward edge for response flow
let engine = frontend let engine = frontend
.link(preprocessor_op.forward_edge())? .link(preprocessor_op.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())? .link(migration.forward_edge())?
.link(backend.forward_edge())?
.link(prefill_op.forward_edge())? .link(prefill_op.forward_edge())?
.link(service_backend)? .link(service_backend)?
.link(prefill_op.backward_edge())? .link(prefill_op.backward_edge())?
.link(migration.backward_edge())?
.link(backend.backward_edge())? .link(backend.backward_edge())?
.link(migration.backward_edge())?
.link(preprocessor_op.backward_edge())? .link(preprocessor_op.backward_edge())?
.link(frontend)?; .link(frontend)?;
......
...@@ -11,8 +11,8 @@ use async_nats::client::{ ...@@ -11,8 +11,8 @@ use async_nats::client::{
}; };
use crate::{ use crate::{
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard, preprocessor::BackendOutput,
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::PreprocessedRequest,
}; };
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -44,16 +44,16 @@ impl Migration { ...@@ -44,16 +44,16 @@ impl Migration {
impl impl
Operator< Operator<
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>, ManyOut<Annotated<BackendOutput>>,
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>, ManyOut<Annotated<BackendOutput>>,
> for Migration > for Migration
{ {
async fn generate( async fn generate(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<BackendOutput>>> {
let (preprocessed_request, context) = request.transfer(()); let (preprocessed_request, context) = request.transfer(());
let engine_ctx = context.context(); let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone(); let engine_ctx_ = engine_ctx.clone();
...@@ -73,8 +73,8 @@ impl ...@@ -73,8 +73,8 @@ impl
struct RetryManager { struct RetryManager {
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
request: PreprocessedRequest, request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>, next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
retries_left: u32, retries_left: u32,
} }
...@@ -82,7 +82,7 @@ impl RetryManager { ...@@ -82,7 +82,7 @@ impl RetryManager {
pub async fn build( pub async fn build(
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
preprocessed_request: PreprocessedRequest, preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
retries_left: u32, retries_left: u32,
) -> Result<Self> { ) -> Result<Self> {
let mut slf = Self { let mut slf = Self {
...@@ -96,7 +96,7 @@ impl RetryManager { ...@@ -96,7 +96,7 @@ impl RetryManager {
Ok(slf) Ok(slf)
} }
pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> { pub async fn next(&mut self) -> Option<Annotated<BackendOutput>> {
loop { loop {
let response_stream = match self.next_stream.as_mut() { let response_stream = match self.next_stream.as_mut() {
Some(stream) => stream, Some(stream) => stream,
...@@ -128,7 +128,7 @@ impl RetryManager { ...@@ -128,7 +128,7 @@ impl RetryManager {
} }
async fn new_stream(&mut self) -> Result<()> { async fn new_stream(&mut self) -> Result<()> {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None; let mut response_stream: Option<Result<ManyOut<Annotated<BackendOutput>>>> = None;
while self.retries_left > 0 { while self.retries_left > 0 {
self.retries_left -= 1; self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context.id().to_string()); let request = Context::with_id(self.request.clone(), self.context.id().to_string());
...@@ -162,7 +162,7 @@ impl RetryManager { ...@@ -162,7 +162,7 @@ impl RetryManager {
} }
} }
fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) { fn track_response(&mut self, response: &Annotated<BackendOutput>) {
if self.retries_left == 0 { if self.retries_left == 0 {
return; return;
} }
...@@ -207,18 +207,17 @@ mod tests { ...@@ -207,18 +207,17 @@ mod tests {
} }
// Helper to create mock LLM engine output // Helper to create mock LLM engine output
fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> { fn create_mock_output(token_id: u32) -> Annotated<BackendOutput> {
Annotated::from_data(LLMEngineOutput { Annotated::from_data(BackendOutput {
token_ids: vec![token_id], token_ids: vec![token_id],
tokens: None, tokens: vec![],
text: Some(format!("token_{}", token_id)), text: Some(format!("token_{token_id}")),
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None,
completion_usage: None, completion_usage: None,
}) })
} }
...@@ -267,16 +266,13 @@ mod tests { ...@@ -267,16 +266,13 @@ mod tests {
#[async_trait] #[async_trait]
impl impl
AsyncEngine< AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, anyhow::Error>
SingleIn<PreprocessedRequest>, for MockEngine
ManyOut<Annotated<LLMEngineOutput>>,
anyhow::Error,
> for MockEngine
{ {
async fn generate( async fn generate(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<BackendOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst); let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, context) = request.transfer(()); let (preprocessed_request, context) = request.transfer(());
...@@ -457,7 +453,7 @@ mod tests { ...@@ -457,7 +453,7 @@ mod tests {
&self, &self,
start: usize, start: usize,
end: usize, end: usize,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<BackendOutput>>> {
let (tx, rx) = mpsc::channel(1); let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset; let token_offset = self.token_offset;
...@@ -494,7 +490,7 @@ mod tests { ...@@ -494,7 +490,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
...@@ -533,7 +529,7 @@ mod tests { ...@@ -533,7 +529,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
...@@ -573,7 +569,7 @@ mod tests { ...@@ -573,7 +569,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
...@@ -613,7 +609,7 @@ mod tests { ...@@ -613,7 +609,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries // Should fail to build due to initial stream creation failure after exhausting all 3 retries
...@@ -641,7 +637,7 @@ mod tests { ...@@ -641,7 +637,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
...@@ -690,7 +686,7 @@ mod tests { ...@@ -690,7 +686,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
...@@ -739,7 +735,7 @@ mod tests { ...@@ -739,7 +735,7 @@ mod tests {
100, 100,
context_id.clone(), context_id.clone(),
)); ));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> = let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine; mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone())); let ctx = Arc::new(Controller::new(context_id.clone()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment