Unverified Commit 6ba64c31 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: tensor type for generic inference. (#2746)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarOlga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent f9be2e9e
...@@ -165,6 +165,7 @@ fn register_llm<'p>( ...@@ -165,6 +165,7 @@ fn register_llm<'p>(
let model_input = match model_input { let model_input = match model_input {
ModelInput::Text => llm_rs::model_type::ModelInput::Text, ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens, ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
}; };
let model_type_obj = model_type.inner; let model_type_obj = model_type.inner;
...@@ -298,6 +299,10 @@ impl ModelType { ...@@ -298,6 +299,10 @@ impl ModelType {
const Embedding: Self = ModelType { const Embedding: Self = ModelType {
inner: llm_rs::model_type::ModelType::Embedding, inner: llm_rs::model_type::ModelType::Embedding,
}; };
#[classattr]
const TensorBased: Self = ModelType {
inner: llm_rs::model_type::ModelType::TensorBased,
};
fn __or__(&self, other: &Self) -> Self { fn __or__(&self, other: &Self) -> Self {
ModelType { ModelType {
...@@ -315,6 +320,7 @@ impl ModelType { ...@@ -315,6 +320,7 @@ impl ModelType {
enum ModelInput { enum ModelInput {
Text = 1, Text = 1,
Tokens = 2, Tokens = 2,
Tensor = 3,
} }
#[pymethods] #[pymethods]
......
...@@ -52,6 +52,27 @@ impl ModelRuntimeConfig { ...@@ -52,6 +52,27 @@ impl ModelRuntimeConfig {
Ok(()) Ok(())
} }
fn set_tensor_model_config(
&mut self,
_py: Python<'_>,
tensor_model_config: &Bound<'_, PyDict>,
) -> PyResult<()> {
let tensor_model_config = pythonize::depythonize(tensor_model_config).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to convert tensor_model_config: {}", err))
})?;
self.inner.tensor_model_config = Some(tensor_model_config);
Ok(())
}
fn get_tensor_model_config(&self, _py: Python<'_>) -> PyResult<Option<PyObject>> {
if let Some(tensor_model_config) = &self.inner.tensor_model_config {
let py_obj = pythonize::pythonize(_py, tensor_model_config).map_err(to_pyerr)?;
Ok(Some(py_obj.unbind()))
} else {
Ok(None)
}
}
#[getter] #[getter]
fn total_kv_blocks(&self) -> Option<u64> { fn total_kv_blocks(&self) -> Option<u64> {
self.inner.total_kv_blocks self.inner.total_kv_blocks
......
...@@ -849,11 +849,11 @@ class HttpAsyncEngine: ...@@ -849,11 +849,11 @@ class HttpAsyncEngine:
... ...
class ModelInput: class ModelInput:
"""What type of request this model needs: Text or Tokens""" """What type of request this model needs: Text, Tokens or Tensor"""
... ...
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Completions or Embedding""" """What type of request this model needs: Chat, Completions, Embedding or Tensor"""
... ...
class RouterMode: class RouterMode:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.
import os
import uvloop
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@dynamo_worker(static=False)
async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor")
await component.create_service()
endpoint = component.endpoint("generate")
model_config = {
"name": "tensor",
"inputs": [
{"name": "input_text", "data_type": "Bytes", "shape": [-1]},
{"name": "custom", "data_type": "Bytes", "shape": [-1]},
{"name": "streaming", "data_type": "Bool", "shape": [1]},
],
"outputs": [{"name": "output_text", "data_type": "Bytes", "shape": [-1]}],
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)
assert model_config == runtime_config.get_tensor_model_config()
# [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide
# a valid model name to pass the registration.
await register_llm(
ModelInput.Tensor,
ModelType.TensorBased,
endpoint,
"Qwen/Qwen3-0.6B",
"tensor",
runtime_config=runtime_config,
)
if TEST_END_TO_END:
await endpoint.serve_endpoint(generate)
async def generate(request, context):
print(f"Received request: {request}")
# Echo input_text in output_text
output_text = None
streaming = False
for tensor in request["tensors"]:
if tensor["metadata"]["name"] == "input_text":
input_text_str = "".join(map(chr, tensor["data"]["values"][0]))
print(f"Input text: {input_text_str}")
output_text = tensor
output_text["metadata"]["name"] = "output_text"
if tensor["metadata"]["name"] == "streaming":
streaming = tensor["data"]["values"][0]
if output_text is None:
raise ValueError("input_text tensor not found in request")
if streaming:
for i in range(len(output_text["data"]["values"][0])):
chunk = {
"model": request["model"],
"tensors": [
{
"metadata": output_text["metadata"],
"data": {
"data_type": output_text["data"]["data_type"],
"values": [[output_text["data"]["values"][0][i]]],
},
}
],
}
yield chunk
else:
yield {"model": request["model"], "tensors": [output_text]}
if __name__ == "__main__":
uvloop.run(test_register())
...@@ -15,6 +15,7 @@ use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry}; ...@@ -15,6 +15,7 @@ use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector}; use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{ use crate::{
kv_router::KvRouter, kv_router::KvRouter,
types::generic::tensor::TensorStreamingEngine,
types::openai::{ types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
...@@ -36,6 +37,7 @@ pub struct ModelManager { ...@@ -36,6 +37,7 @@ pub struct ModelManager {
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>, completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>, chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally // These two are Mutex because we read and write rarely and equally
entries: Mutex<HashMap<String, ModelEntry>>, entries: Mutex<HashMap<String, ModelEntry>>,
...@@ -54,6 +56,7 @@ impl ModelManager { ...@@ -54,6 +56,7 @@ impl ModelManager {
completion_engines: RwLock::new(ModelEngines::default()), completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()), chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
entries: Mutex::new(HashMap::new()), entries: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
} }
...@@ -73,6 +76,7 @@ impl ModelManager { ...@@ -73,6 +76,7 @@ impl ModelManager {
.into_iter() .into_iter()
.chain(self.list_completions_models()) .chain(self.list_completions_models())
.chain(self.list_embeddings_models()) .chain(self.list_embeddings_models())
.chain(self.list_tensor_models())
.collect() .collect()
} }
...@@ -88,6 +92,10 @@ impl ModelManager { ...@@ -88,6 +92,10 @@ impl ModelManager {
self.embeddings_engines.read().list() self.embeddings_engines.read().list()
} }
pub fn list_tensor_models(&self) -> Vec<String> {
self.tensor_engines.read().list()
}
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
...@@ -115,6 +123,15 @@ impl ModelManager { ...@@ -115,6 +123,15 @@ impl ModelManager {
clients.add(model, engine) clients.add(model, engine)
} }
pub fn add_tensor_model(
&self,
model: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write(); let mut clients = self.completion_engines.write();
clients.remove(model) clients.remove(model)
...@@ -130,6 +147,11 @@ impl ModelManager { ...@@ -130,6 +147,11 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.remove(model)
}
pub fn get_embeddings_engine( pub fn get_embeddings_engine(
&self, &self,
model: &str, model: &str,
...@@ -163,6 +185,17 @@ impl ModelManager { ...@@ -163,6 +185,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_tensor_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.tensor_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is /// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd. /// deleted from etcd.
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) { pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
......
...@@ -33,6 +33,7 @@ use crate::{ ...@@ -33,6 +33,7 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}, },
tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
}, },
}; };
...@@ -59,6 +60,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -59,6 +60,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Chat, ModelType::Chat,
ModelType::Completions, ModelType::Completions,
ModelType::Embedding, ModelType::Embedding,
ModelType::TensorBased,
]; ];
impl ModelWatcher { impl ModelWatcher {
...@@ -213,10 +215,12 @@ impl ModelWatcher { ...@@ -213,10 +215,12 @@ impl ModelWatcher {
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name); let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name); let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name); let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let mut chat_model_removed = false; let mut chat_model_removed = false;
let mut completions_model_removed = false; let mut completions_model_removed = false;
let mut embeddings_model_removed = false; let mut embeddings_model_removed = false;
let mut tensor_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() { if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true; chat_model_removed = true;
...@@ -228,20 +232,29 @@ impl ModelWatcher { ...@@ -228,20 +232,29 @@ impl ModelWatcher {
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() { if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true; embeddings_model_removed = true;
} }
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if !chat_model_removed && !completions_model_removed && !embeddings_model_removed { if !chat_model_removed
&& !completions_model_removed
&& !embeddings_model_removed
&& !tensor_model_removed
{
tracing::debug!( tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}", "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
model_name, model_name,
chat_model_removed, chat_model_removed,
completions_model_removed, completions_model_removed,
embeddings_model_removed embeddings_model_removed,
tensor_model_removed
); );
} else { } else {
for model_type in ALL_MODEL_TYPES { for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat) if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions) || (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)) || (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
tx.send(ModelUpdate::Removed(*model_type)).await.ok(); tx.send(ModelUpdate::Removed(*model_type)).await.ok();
...@@ -421,11 +434,24 @@ impl ModelWatcher { ...@@ -421,11 +434,24 @@ impl ModelWatcher {
self.manager self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?; .add_embeddings_model(&model_entry.name, embedding_engine)?;
} else if model_entry.model_input == ModelInput::Tensor
&& model_entry.model_type.supports_tensor()
{
// Case 5: Tensor + Tensor (non-LLM)
let push_router = PushRouter::<
NvCreateTensorRequest,
Annotated<NvCreateTensorResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager.add_tensor_model(&model_entry.name, engine)?;
} else { } else {
// Reject unsupported combinations // Reject unsupported combinations
anyhow::bail!( anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \ "Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings", Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
model_entry.model_type, model_entry.model_type,
model_entry.model_input.as_str() model_entry.model_input.as_str()
); );
......
...@@ -3,3 +3,4 @@ ...@@ -3,3 +3,4 @@
pub mod kserve; pub mod kserve;
pub mod openai; pub mod openai;
pub mod tensor;
This diff is collapsed.
...@@ -15,12 +15,14 @@ use crate::protocols::openai::completions::{ ...@@ -15,12 +15,14 @@ use crate::protocols::openai::completions::{
use crate::types::Annotated; use crate::types::Annotated;
use super::kserve; use super::kserve;
use super::kserve::inference;
// [gluo NOTE] These are common utilities that should be shared between frontends // [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{ use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor}, disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics}, metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
}; };
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
use tonic::Status; use tonic::Status;
...@@ -185,3 +187,205 @@ fn get_or_create_request_id(primary: Option<&str>) -> String { ...@@ -185,3 +187,205 @@ fn get_or_create_request_id(primary: Option<&str>) -> String {
let uuid = uuid::Uuid::new_v4(); let uuid = uuid::Uuid::new_v4();
uuid.to_string() uuid.to_string()
} }
impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
type Error = Status;
fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
// Protocol requires if `raw_input_contents` is used to hold input data,
// it must be used for all inputs.
if !request.raw_input_contents.is_empty()
&& request.inputs.len() != request.raw_input_contents.len()
{
return Err(Status::invalid_argument(
"`raw_input_contents` must be used for all inputs",
));
}
// iterate through inputs
let mut text_input = None;
let mut stream = false;
for (idx, input) in request.inputs.iter().enumerate() {
match input.name.as_str() {
"text_input" => {
if input.datatype != "BYTES" {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to be of type BYTES for string input, got {:?}",
input.datatype
)));
}
if input.shape != vec![1] && input.shape != vec![1, 1] {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
let bytes = content.bytes_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'text_input' must contain exactly one element",
)
})?;
text_input = Some(String::from_utf8_lossy(bytes).to_string());
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'text_input'")
})?;
if raw_input.len() < 4 {
return Err(Status::invalid_argument(
"'text_input' raw input must be length-prefixed (>= 4 bytes)",
));
}
// We restrict the 'text_input' only contain one element, only need to
// parse the first element. Skip first four bytes that is used to store
// the length of the input.
text_input = Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
}
}
}
"streaming" | "stream" => {
if input.datatype != "BOOL" {
return Err(Status::invalid_argument(format!(
"Expected '{}' to be of type BOOL, got {:?}",
input.name, input.datatype
)));
}
if input.shape != vec![1] {
return Err(Status::invalid_argument(format!(
"Expected 'stream' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
stream = *content.bool_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'stream' must contain exactly one element",
)
})?;
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'stream'")
})?;
if raw_input.is_empty() {
return Err(Status::invalid_argument(
"'stream' raw input must contain at least one byte",
));
}
stream = raw_input[0] != 0;
}
}
}
_ => {
return Err(Status::invalid_argument(format!(
"Invalid input name: {}, supported inputs are 'text_input', 'stream'",
input.name
)));
}
}
}
// return error if text_input is None
let text_input = match text_input {
Some(input) => input,
None => {
return Err(Status::invalid_argument(
"Missing required input: 'text_input'",
));
}
};
Ok(NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: request.model_name,
prompt: Prompt::String(text_input),
stream: Some(stream),
user: if request.id.is_empty() {
None
} else {
Some(request.id.clone())
},
..Default::default()
},
common: Default::default(),
nvext: None,
})
}
}
impl TryFrom<NvCreateCompletionResponse> for inference::ModelInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
let mut outputs = vec![];
let mut text_output = vec![];
let mut finish_reason = vec![];
for choice in &response.inner.choices {
text_output.push(choice.text.clone());
let reason_str = match choice.finish_reason.as_ref() {
Some(CompletionFinishReason::Stop) => "stop",
Some(CompletionFinishReason::Length) => "length",
Some(CompletionFinishReason::ContentFilter) => "content_filter",
None => "",
};
finish_reason.push(reason_str.to_string());
}
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "text_output".to_string(),
datatype: "BYTES".to_string(),
shape: vec![text_output.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: text_output
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "finish_reason".to_string(),
datatype: "BYTES".to_string(),
shape: vec![finish_reason.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: finish_reason
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
Ok(inference::ModelInferResponse {
model_name: response.inner.model,
model_version: "1".to_string(),
id: response.inner.id,
outputs,
parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
raw_output_contents: vec![],
})
}
}
impl TryFrom<NvCreateCompletionResponse> for inference::ModelStreamInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
match inference::ModelInferResponse::try_from(response) {
Ok(response) => Ok(inference::ModelStreamInferResponse {
infer_response: Some(response),
..Default::default()
}),
Err(e) => Ok(inference::ModelStreamInferResponse {
infer_response: None,
error_message: format!("Failed to convert response: {}", e),
}),
}
}
}
This diff is collapsed.
...@@ -73,6 +73,9 @@ pub enum Endpoint { ...@@ -73,6 +73,9 @@ pub enum Endpoint {
/// OAI Responses /// OAI Responses
Responses, Responses,
/// Tensor
Tensor,
} }
/// Metrics for the HTTP service /// Metrics for the HTTP service
...@@ -456,6 +459,7 @@ impl std::fmt::Display for Endpoint { ...@@ -456,6 +459,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::ChatCompletions => write!(f, "chat_completions"), Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"), Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Responses => write!(f, "responses"), Endpoint::Responses => write!(f, "responses"),
Endpoint::Tensor => write!(f, "tensor"),
} }
} }
} }
...@@ -467,6 +471,7 @@ impl Endpoint { ...@@ -467,6 +471,7 @@ impl Endpoint {
Endpoint::ChatCompletions => "chat_completions", Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings", Endpoint::Embeddings => "embeddings",
Endpoint::Responses => "responses", Endpoint::Responses => "responses",
Endpoint::Tensor => "tensor",
} }
} }
} }
......
...@@ -5,6 +5,8 @@ use std::collections::HashMap; ...@@ -5,6 +5,8 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor;
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig { pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>, pub total_kv_blocks: Option<u64>,
...@@ -20,6 +22,16 @@ pub struct ModelRuntimeConfig { ...@@ -20,6 +22,16 @@ pub struct ModelRuntimeConfig {
/// Mapping of engine-specific runtime configs /// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")] #[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>, pub runtime_data: HashMap<String, serde_json::Value>,
// Provide tensor model config in the case where the model type is Tensor.
// Currently use JSON object for convinence, the programmatic way is to
// define the model config struct as part of the tensor protocol and
// import it here.
// [gluo TODO] switch to ModelConfig if desired and workout a way to
// prepare it in a convinent way, the protobuf library used by tonic
// doesn't provide JSON parsing.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tensor_model_config: Option<tensor::TensorModelConfig>,
} }
impl ModelRuntimeConfig { impl ModelRuntimeConfig {
......
...@@ -15,6 +15,7 @@ bitflags! { ...@@ -15,6 +15,7 @@ bitflags! {
/// - `ModelType::Chat` /// - `ModelType::Chat`
/// - `ModelType::Completions` /// - `ModelType::Completions`
/// - `ModelType::Embedding` /// - `ModelType::Embedding`
/// - `ModelType::TensorBased`
/// ///
/// For example, a model that supports both chat and completions can be /// For example, a model that supports both chat and completions can be
/// expressed as: /// expressed as:
...@@ -34,6 +35,7 @@ bitflags! { ...@@ -34,6 +35,7 @@ bitflags! {
const Chat = 1 << 0; const Chat = 1 << 0;
const Completions = 1 << 1; const Completions = 1 << 1;
const Embedding = 1 << 2; const Embedding = 1 << 2;
const TensorBased = 1 << 3;
} }
} }
...@@ -51,6 +53,9 @@ impl ModelType { ...@@ -51,6 +53,9 @@ impl ModelType {
pub fn supports_embedding(&self) -> bool { pub fn supports_embedding(&self) -> bool {
self.contains(ModelType::Embedding) self.contains(ModelType::Embedding)
} }
pub fn supports_tensor(&self) -> bool {
self.contains(ModelType::TensorBased)
}
pub fn as_vec(&self) -> Vec<&'static str> { pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new(); let mut result = Vec::new();
...@@ -63,6 +68,9 @@ impl ModelType { ...@@ -63,6 +68,9 @@ impl ModelType {
if self.supports_embedding() { if self.supports_embedding() {
result.push("embedding"); result.push("embedding");
} }
if self.supports_tensor() {
result.push("tensor");
}
result result
} }
...@@ -79,6 +87,9 @@ impl ModelType { ...@@ -79,6 +87,9 @@ impl ModelType {
if self.contains(Self::Embedding) { if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding); endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
} }
// [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
// current use of endpoint type is LLM specific and so does the HTTP
// server that uses it.
endpoint_types endpoint_types
} }
} }
...@@ -95,6 +106,8 @@ pub enum ModelInput { ...@@ -95,6 +106,8 @@ pub enum ModelInput {
Text, Text,
/// Pre-processed input /// Pre-processed input
Tokens, Tokens,
/// Tensor input
Tensor,
} }
impl ModelInput { impl ModelInput {
...@@ -102,6 +115,7 @@ impl ModelInput { ...@@ -102,6 +115,7 @@ impl ModelInput {
match self { match self {
Self::Text => "text", Self::Text => "text",
Self::Tokens => "tokens", Self::Tokens => "tokens",
Self::Tensor => "tensor",
} }
} }
} }
...@@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize}; ...@@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize};
pub mod codec; pub mod codec;
pub mod common; pub mod common;
pub mod openai; pub mod openai;
pub mod tensor;
/// The token ID type /// The token ID type
pub type TokenIdType = u32; pub type TokenIdType = u32;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::protocols::Annotated;
use anyhow::Result;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use validator::Validate;
// [gluo TODO] whether it makes sense to have aggregator for tensor..
// we could if considering aggregation to be stacking the tensors by adding
// one more dimension. i.e. stream of [2, 2] tensors to be aggregated to
// [-1, 2, 2]. Will decide it later and currently do not allow aggregation.
// mod aggregator;
// pub use aggregator::DeltaAggregator;
// [gluo TODO] nvext is LLM specific, we really only use the annotation field
pub use super::openai::nvext::{NvExt, NvExtProvider};
#[derive(Debug, Serialize, Clone, Eq, PartialEq, Deserialize)]
pub enum DataType {
Bool,
Uint8,
Uint16,
Uint32,
Uint64,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
Bytes,
}
impl DataType {
pub fn size(&self) -> usize {
match self {
DataType::Bool => size_of::<bool>(),
DataType::Uint8 => size_of::<u8>(),
DataType::Uint16 => size_of::<u16>(),
DataType::Uint32 => size_of::<u32>(),
DataType::Uint64 => size_of::<u64>(),
DataType::Int8 => size_of::<i8>(),
DataType::Int16 => size_of::<i16>(),
DataType::Int32 => size_of::<i32>(),
DataType::Int64 => size_of::<i64>(),
DataType::Float32 => size_of::<f32>(),
DataType::Float64 => size_of::<f64>(),
DataType::Bytes => 0, // variable length, return 0 as indicator
}
}
}
#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
// Self-describing encoding removes ambiguity between signed/unsigned and width variants.
#[serde(tag = "data_type", content = "values")]
pub enum FlattenTensor {
Bool(Vec<bool>),
// [gluo NOTE] f16, and bf16 is not stably supported
Uint8(Vec<u8>),
Uint16(Vec<u16>),
Uint32(Vec<u32>),
Uint64(Vec<u64>),
Int8(Vec<i8>),
Int16(Vec<i16>),
Int32(Vec<i32>),
Int64(Vec<i64>),
Float32(Vec<f32>),
Float64(Vec<f64>),
// Typically use to store string data, but really it can store
// arbitrary data such as serialized handles for custom worker behavior.
Bytes(Vec<Vec<u8>>),
}
#[allow(clippy::len_without_is_empty)]
impl FlattenTensor {
pub fn len(&self) -> usize {
match self {
Self::Bool(v) => v.len(),
Self::Uint8(v) => v.len(),
Self::Uint16(v) => v.len(),
Self::Uint32(v) => v.len(),
Self::Uint64(v) => v.len(),
Self::Int8(v) => v.len(),
Self::Int16(v) => v.len(),
Self::Int32(v) => v.len(),
Self::Int64(v) => v.len(),
Self::Float32(v) => v.len(),
Self::Float64(v) => v.len(),
Self::Bytes(v) => v.len(),
}
}
pub fn data_type(&self) -> DataType {
match self {
Self::Bool(_) => DataType::Bool,
Self::Uint8(_) => DataType::Uint8,
Self::Uint16(_) => DataType::Uint16,
Self::Uint32(_) => DataType::Uint32,
Self::Uint64(_) => DataType::Uint64,
Self::Int8(_) => DataType::Int8,
Self::Int16(_) => DataType::Int16,
Self::Int32(_) => DataType::Int32,
Self::Int64(_) => DataType::Int64,
Self::Float32(_) => DataType::Float32,
Self::Float64(_) => DataType::Float64,
Self::Bytes(_) => DataType::Bytes,
}
}
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorMetadata {
pub name: String,
pub data_type: DataType,
pub shape: Vec<i64>,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tensor {
pub metadata: TensorMetadata,
pub data: FlattenTensor,
}
impl validator::Validate for Tensor {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
use validator::{ValidationError, ValidationErrors};
let mut errs = ValidationErrors::new();
// dtype must match
if self.metadata.data_type != self.data.data_type() {
let mut e = ValidationError::new("dtype_mismatch");
e.message = Some("metadata.data_type does not match data variant".into());
errs.add("data_type", e);
}
let mut product: usize = 1;
for &d in &self.metadata.shape {
if d < 0 {
let mut e = ValidationError::new("negative_dim");
e.message = Some("only -1 is allowed as a wildcard dimension".into());
errs.add("shape", e);
break;
}
product = product.saturating_mul(d as usize);
}
// bytes payloads may be variable-length per item; enforce outer count only
let expect_count = self.data.len();
if product != expect_count {
let mut e = ValidationError::new("element_count_mismatch");
e.message = Some(
format!(
"shape implies {} elements but data has {}",
product, expect_count
)
.into(),
);
errs.add("shape", e);
}
if errs.is_empty() { Ok(()) } else { Err(errs) }
}
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorRequest {
/// ID of the request
pub id: Option<String>,
/// ID of the model to use.
pub model: String,
/// Input tensors.
pub tensors: Vec<Tensor>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorResponse {
/// ID of the corresponding request.
pub id: Option<String>,
/// ID of the model.
pub model: String,
/// Output tensors.
pub tensors: Vec<Tensor>,
}
/// Implements `NvExtProvider` for `NvCreateTensorRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateTensorRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
fn raw_prompt(&self) -> Option<String> {
// Not really apply here.
None
}
}
/// Implements `AnnotationsProvider` for `NvCreateTensorRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateTensorRequest {
/// Retrieves the list of annotations from `NvExt`, if present.
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.clone())
}
/// Checks whether a specific annotation exists in the request.
///
/// # Arguments
/// * `annotation` - A string slice representing the annotation to check.
///
/// # Returns
/// `true` if the annotation exists, `false` otherwise.
fn has_annotation(&self, annotation: &str) -> bool {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.as_ref())
.map(|annotations| annotations.contains(&annotation.to_string()))
.unwrap_or(false)
}
}
pub struct DeltaAggregator {
response: Option<NvCreateTensorResponse>,
error: Option<String>,
}
impl NvCreateTensorResponse {
pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateTensorResponse>>,
) -> Result<NvCreateTensorResponse> {
let aggregator = stream
.fold(
DeltaAggregator {
response: None,
error: None,
},
|mut aggregator, delta| async move {
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
if aggregator.error.is_none() {
aggregator.error = Some(error);
}
return aggregator;
}
};
match delta.data {
Some(resp) => {
if aggregator.response.is_none() {
aggregator.response = Some(resp);
} else if aggregator.error.is_none() {
aggregator.error =
Some("Multiple responses in non-streaming mode".to_string());
}
}
None => {
// Ignore metadata-only deltas in non-streaming mode.
}
}
aggregator
},
)
.await;
if let Some(error) = aggregator.error {
Err(anyhow::anyhow!(error))
} else if let Some(response) = aggregator.response {
Ok(response)
} else {
Err(anyhow::anyhow!("No response received"))
}
}
}
...@@ -60,3 +60,21 @@ pub mod openai { ...@@ -60,3 +60,21 @@ pub mod openai {
ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>; ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>;
} }
} }
pub mod generic {
use super::*;
use dynamo_runtime::pipeline::{ServerStreamingEngine, UnaryEngine};
pub mod tensor {
use super::*;
pub use protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
/// A [`UnaryEngine`] implementation for the generic Tensor API
pub type TensorUnaryEngine = UnaryEngine<NvCreateTensorRequest, NvCreateTensorResponse>;
/// A [`ServerStreamingEngine`] implementation for the generic Tensor API
pub type TensorStreamingEngine =
ServerStreamingEngine<NvCreateTensorRequest, Annotated<NvCreateTensorResponse>>;
}
}
...@@ -212,6 +212,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu ...@@ -212,6 +212,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
Endpoint::ChatCompletions => 1, Endpoint::ChatCompletions => 1,
Endpoint::Embeddings => todo!(), Endpoint::Embeddings => todo!(),
Endpoint::Responses => todo!(), Endpoint::Responses => todo!(),
Endpoint::Tensor => todo!(),
}; };
let request_type = match request_type { let request_type = match request_type {
......
...@@ -6,6 +6,11 @@ pub mod kserve_test { ...@@ -6,6 +6,11 @@ pub mod kserve_test {
pub mod inference { pub mod inference {
tonic::include_proto!("inference"); tonic::include_proto!("inference");
} }
use dynamo_llm::discovery::ModelEntry;
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_llm::model_type::{ModelInput, ModelType};
use dynamo_llm::protocols::tensor;
use dynamo_runtime::protocols::EndpointId;
use inference::grpc_inference_service_client::GrpcInferenceServiceClient; use inference::grpc_inference_service_client::GrpcInferenceServiceClient;
use inference::{ use inference::{
DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest, DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest,
...@@ -22,6 +27,7 @@ pub mod kserve_test { ...@@ -22,6 +27,7 @@ pub mod kserve_test {
}, },
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
}; };
use dynamo_runtime::{ use dynamo_runtime::{
CancellationToken, CancellationToken,
...@@ -181,6 +187,52 @@ pub mod kserve_test { ...@@ -181,6 +187,52 @@ pub mod kserve_test {
} }
} }
struct TensorEngine {}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateTensorRequest>,
ManyOut<Annotated<NvCreateTensorResponse>>,
Error,
> for TensorEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateTensorRequest>,
) -> Result<ManyOut<Annotated<NvCreateTensorResponse>>, Error> {
// Echo input tensor in response, additionally check if there is input tensor
// name "repeat", if so, send the same response as many time as the value of the tensor
let (request, context) = request.transfer(());
let ctx = context.context();
let repeat_count = request
.tensors
.iter()
.find_map(|t| {
if t.metadata.name == "repeat"
&& let tensor::FlattenTensor::Int32(data) = &t.data
&& !data.is_empty()
{
return Some(data[0]);
}
None
})
.unwrap_or(1);
let stream = async_stream::stream! {
for _ in 0..repeat_count {
yield Annotated::from_data(NvCreateTensorResponse {
id: request.id.clone(),
model: request.model.clone(),
tensors: request.tensors.clone(),
});
}
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
/// Wait for the HTTP service to be ready by checking its health endpoint /// Wait for the HTTP service to be ready by checking its health endpoint
async fn get_ready_client(port: u16, timeout_secs: u64) -> GrpcInferenceServiceClient<Channel> { async fn get_ready_client(port: u16, timeout_secs: u64) -> GrpcInferenceServiceClient<Channel> {
let start = tokio::time::Instant::now(); let start = tokio::time::Instant::now();
...@@ -232,15 +284,58 @@ pub mod kserve_test { ...@@ -232,15 +284,58 @@ pub mod kserve_test {
manager manager
.add_completions_model("split", split.clone()) .add_completions_model("split", split.clone())
.unwrap(); .unwrap();
manager.save_model_entry(
"split",
ModelEntry {
name: "split".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "split".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager manager
.add_chat_completions_model("failure", failure.clone()) .add_chat_completions_model("failure", failure.clone())
.unwrap(); .unwrap();
manager manager
.add_completions_model("failure", failure.clone()) .add_completions_model("failure", failure.clone())
.unwrap(); .unwrap();
manager.save_model_entry(
"failure",
ModelEntry {
name: "failure".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "failure".to_string(),
},
model_type: ModelType::Completions | ModelType::Chat,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager manager
.add_completions_model("long_running", long_running.clone()) .add_completions_model("long_running", long_running.clone())
.unwrap(); .unwrap();
manager.save_model_entry(
"long_running",
ModelEntry {
name: "long_running".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "long_running".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
(service, split, failure, long_running) (service, split, failure, long_running)
} }
...@@ -276,6 +371,7 @@ pub mod kserve_test { ...@@ -276,6 +371,7 @@ pub mod kserve_test {
InferCancellation = 8992, InferCancellation = 8992,
StreamInferCancellation = 8993, StreamInferCancellation = 8993,
ModelInfo = 8994, ModelInfo = 8994,
TensorModel = 8995,
} }
#[rstest] #[rstest]
...@@ -474,8 +570,8 @@ pub mod kserve_test { ...@@ -474,8 +570,8 @@ pub mod kserve_test {
); );
assert_eq!( assert_eq!(
output.shape, output.shape,
vec![0], vec![1],
"Expected 'finish_reason' to have shape [0]" "Expected 'finish_reason' to have shape [1]"
); );
} }
_ => panic!("Unexpected output name: {}", output.name), _ => panic!("Unexpected output name: {}", output.name),
...@@ -632,8 +728,8 @@ pub mod kserve_test { ...@@ -632,8 +728,8 @@ pub mod kserve_test {
); );
assert_eq!( assert_eq!(
output.shape, output.shape,
vec![0], vec![1],
"Expected 'finish_reason' to have shape [0]" "Expected 'finish_reason' to have shape [1]"
); );
} }
_ => panic!("Unexpected output name: {}", output.name), _ => panic!("Unexpected output name: {}", output.name),
...@@ -724,8 +820,8 @@ pub mod kserve_test { ...@@ -724,8 +820,8 @@ pub mod kserve_test {
); );
assert_eq!( assert_eq!(
output.shape, output.shape,
vec![0], vec![1],
"Expected 'finish_reason' to have shape [0]" "Expected 'finish_reason' to have shape [1]"
); );
} }
_ => panic!("Unexpected output name: {}", output.name), _ => panic!("Unexpected output name: {}", output.name),
...@@ -1052,4 +1148,326 @@ pub mod kserve_test { ...@@ -1052,4 +1148,326 @@ pub mod kserve_test {
} }
} }
} }
#[rstest]
#[tokio::test]
async fn test_tensor_infer(
#[with(TestPort::TensorModel as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
text_input: inference::model_infer_request::InferInputTensor,
) {
// add tensor model
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", tensor.clone())
.unwrap();
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());
let mut client = get_ready_client(TestPort::TensorModel as u16, 5).await;
let request = tonic::Request::new(ModelMetadataRequest {
name: "tensor".into(),
version: "".into(),
});
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let entry = ModelEntry {
name: "tensor".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "endpoint".to_string(),
},
model_type: ModelType::TensorBased,
model_input: ModelInput::Tensor,
runtime_config: None,
};
service_with_engines
.0
.model_manager()
.save_model_entry("key", entry);
let response = client.model_metadata(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::InvalidArgument,
"Expected InvalidArgument error for unregistered model, get {}",
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message()
);
let request = tonic::Request::new(ModelConfigRequest {
name: "tensor".into(),
version: "".into(),
});
let response = client.model_config(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::InvalidArgument,
"Expected InvalidArgument error for unregistered model, get {}",
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message()
);
// Change model entry to have model config
service_with_engines
.0
.model_manager()
.remove_model_entry("key");
let entry = ModelEntry {
name: "tensor".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "endpoint".to_string(),
},
model_type: ModelType::TensorBased,
model_input: ModelInput::Tensor,
runtime_config: Some(ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
name: "tensor".to_string(),
inputs: vec![tensor::TensorMetadata {
name: "input".to_string(),
data_type: tensor::DataType::Bytes,
shape: vec![1],
}],
outputs: vec![tensor::TensorMetadata {
name: "output".to_string(),
data_type: tensor::DataType::Bool,
shape: vec![-1],
}],
}),
..Default::default()
}),
};
service_with_engines
.0
.model_manager()
.save_model_entry("key", entry);
// Success
let request = tonic::Request::new(ModelMetadataRequest {
name: "tensor".into(),
version: "".into(),
});
let response = client.model_metadata(request).await.unwrap();
assert_eq!(
response.get_ref().name,
"tensor",
"Expected response of the same model name",
);
// input
for io in &response.get_ref().inputs {
match io.name.as_str() {
"input" => {
assert_eq!(
io.datatype, "BYTES",
"Expected 'input' to have datatype 'BYTES'"
);
assert_eq!(io.shape, vec![1], "Expected 'input' to have shape [1]");
}
_ => panic!("Unexpected output name: {}", io.name),
}
}
// output
for io in &response.get_ref().outputs {
match io.name.as_str() {
"output" => {
assert_eq!(
io.datatype, "BOOL",
"Expected 'output' to have datatype 'BOOL'"
);
assert_eq!(io.shape, vec![-1], "Expected 'output' to have shape [-1]");
}
_ => panic!("Unexpected output name: {}", io.name),
}
}
let model_name = "tensor";
let inputs = vec![text_input.clone()];
let request = tonic::Request::new(ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: inputs.clone(),
..Default::default()
});
let response = client.model_infer(request).await.unwrap();
validate_tensor_response(response, model_name, inputs);
// streaming response in model_infer(), expect failure
let repeat = inference::model_infer_request::InferInputTensor {
name: "repeat".into(),
datatype: "INT32".into(),
shape: vec![1],
contents: Some(inference::InferTensorContents {
int_contents: vec![2],
..Default::default()
}),
..Default::default()
};
let inputs = vec![text_input.clone(), repeat.clone()];
let request = tonic::Request::new(ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: inputs.clone(),
..Default::default()
});
let response = client.model_infer(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::Internal,
"Expected Internal error for trying to stream response in ModelInfer, get {}",
err
);
// assert "stream" in error message
assert!(
err.message()
.contains("Multiple responses in non-streaming mode"),
"Expected error message to contain 'Multiple responses in non-streaming mode', got: {}",
err.message()
);
// model_stream_infer() and raw_input_contents
{
let inputs = vec![text_input.clone(), repeat.clone()];
let outbound = async_stream::stream! {
let request_count = 1;
for _ in 0..request_count {
let mut text_input = text_input.clone();
text_input.contents = None; // Clear contents to use raw_input_contents
let text_input_str = "dummy input";
let input_len = text_input_str.len() as u32;
let mut serialized_text_input = input_len.to_le_bytes().to_vec();
serialized_text_input.extend_from_slice(text_input_str.as_bytes());
let mut repeat = repeat.clone();
repeat.contents = None; // Clear contents to use raw_input_contents
let serialized_repeat = 2i32.to_le_bytes().to_vec();
let request = ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: vec![text_input.clone(), repeat.clone()],
raw_input_contents: vec![serialized_text_input, serialized_repeat],
..Default::default()
};
yield request;
}
};
let response = client
.model_stream_infer(Request::new(outbound))
.await
.unwrap();
let mut inbound = response.into_inner();
let mut response_idx = 0;
while let Some(response) = inbound.message().await.unwrap() {
assert!(
response.error_message.is_empty(),
"Expected successful inference"
);
assert!(
response.infer_response.is_some(),
"Expected successful inference"
);
if let Some(response) = &response.infer_response {
validate_tensor_response(
Response::new(response.clone()),
model_name,
inputs.clone(),
);
}
response_idx += 1;
}
assert_eq!(response_idx, 2, "Expected 2 responses")
}
}
fn validate_tensor_response(
response: Response<ModelInferResponse>,
model_name: &str,
inputs: Vec<inference::model_infer_request::InferInputTensor>,
) {
assert_eq!(
response.get_ref().model_name,
model_name,
"Expected response of the same model name",
);
assert_eq!(
response.get_ref().model_version,
"1",
"Expected response of the same model version"
);
assert_eq!(
response.get_ref().id,
"1234",
"Expected response of the same request ID"
);
assert_eq!(
response.get_ref().outputs.len(),
inputs.len(),
"Expected the same number of outputs as inputs",
);
for output in &response.get_ref().outputs {
let mut found = false;
for input in &inputs {
if input.name != output.name {
continue;
}
assert_eq!(
output.name, input.name,
"Expected output name to be '{}', got '{}'",
input.name, output.name
);
assert_eq!(
output.datatype, input.datatype,
"Expected output datatype to be '{}', got '{}'",
input.datatype, output.datatype
);
assert_eq!(
output.shape, input.shape,
"Expected output shape to be '{:?}', got '{:?}'",
input.shape, output.shape
);
found = true;
break;
}
if !found {
panic!("Unexpected output name: {}", output.name);
}
}
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment