Unverified Commit 6ba64c31 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: tensor type for generic inference. (#2746)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarOlga Andreeva <124622579+oandreeva-nv@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent f9be2e9e
......@@ -165,6 +165,7 @@ fn register_llm<'p>(
let model_input = match model_input {
ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor,
};
let model_type_obj = model_type.inner;
......@@ -298,6 +299,10 @@ impl ModelType {
const Embedding: Self = ModelType {
inner: llm_rs::model_type::ModelType::Embedding,
};
#[classattr]
const TensorBased: Self = ModelType {
inner: llm_rs::model_type::ModelType::TensorBased,
};
fn __or__(&self, other: &Self) -> Self {
ModelType {
......@@ -315,6 +320,7 @@ impl ModelType {
enum ModelInput {
Text = 1,
Tokens = 2,
Tensor = 3,
}
#[pymethods]
......
......@@ -52,6 +52,27 @@ impl ModelRuntimeConfig {
Ok(())
}
fn set_tensor_model_config(
&mut self,
_py: Python<'_>,
tensor_model_config: &Bound<'_, PyDict>,
) -> PyResult<()> {
let tensor_model_config = pythonize::depythonize(tensor_model_config).map_err(|err| {
PyErr::new::<PyException, _>(format!("Failed to convert tensor_model_config: {}", err))
})?;
self.inner.tensor_model_config = Some(tensor_model_config);
Ok(())
}
fn get_tensor_model_config(&self, _py: Python<'_>) -> PyResult<Option<PyObject>> {
if let Some(tensor_model_config) = &self.inner.tensor_model_config {
let py_obj = pythonize::pythonize(_py, tensor_model_config).map_err(to_pyerr)?;
Ok(Some(py_obj.unbind()))
} else {
Ok(None)
}
}
#[getter]
fn total_kv_blocks(&self) -> Option<u64> {
self.inner.total_kv_blocks
......
......@@ -849,11 +849,11 @@ class HttpAsyncEngine:
...
class ModelInput:
"""What type of request this model needs: Text or Tokens"""
"""What type of request this model needs: Text, Tokens or Tensor"""
...
class ModelType:
"""What type of request this model needs: Chat, Completions or Embedding"""
"""What type of request this model needs: Chat, Completions, Embedding or Tensor"""
...
class RouterMode:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.
import os
import uvloop
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@dynamo_worker(static=False)
async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor")
await component.create_service()
endpoint = component.endpoint("generate")
model_config = {
"name": "tensor",
"inputs": [
{"name": "input_text", "data_type": "Bytes", "shape": [-1]},
{"name": "custom", "data_type": "Bytes", "shape": [-1]},
{"name": "streaming", "data_type": "Bool", "shape": [1]},
],
"outputs": [{"name": "output_text", "data_type": "Bytes", "shape": [-1]}],
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)
assert model_config == runtime_config.get_tensor_model_config()
# [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide
# a valid model name to pass the registration.
await register_llm(
ModelInput.Tensor,
ModelType.TensorBased,
endpoint,
"Qwen/Qwen3-0.6B",
"tensor",
runtime_config=runtime_config,
)
if TEST_END_TO_END:
await endpoint.serve_endpoint(generate)
async def generate(request, context):
print(f"Received request: {request}")
# Echo input_text in output_text
output_text = None
streaming = False
for tensor in request["tensors"]:
if tensor["metadata"]["name"] == "input_text":
input_text_str = "".join(map(chr, tensor["data"]["values"][0]))
print(f"Input text: {input_text_str}")
output_text = tensor
output_text["metadata"]["name"] = "output_text"
if tensor["metadata"]["name"] == "streaming":
streaming = tensor["data"]["values"][0]
if output_text is None:
raise ValueError("input_text tensor not found in request")
if streaming:
for i in range(len(output_text["data"]["values"][0])):
chunk = {
"model": request["model"],
"tensors": [
{
"metadata": output_text["metadata"],
"data": {
"data_type": output_text["data"]["data_type"],
"values": [[output_text["data"]["values"][0][i]]],
},
}
],
}
yield chunk
else:
yield {"model": request["model"], "tensors": [output_text]}
if __name__ == "__main__":
uvloop.run(test_register())
......@@ -15,6 +15,7 @@ use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{
kv_router::KvRouter,
types::generic::tensor::TensorStreamingEngine,
types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
......@@ -36,6 +37,7 @@ pub struct ModelManager {
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally
entries: Mutex<HashMap<String, ModelEntry>>,
......@@ -54,6 +56,7 @@ impl ModelManager {
completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
entries: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()),
}
......@@ -73,6 +76,7 @@ impl ModelManager {
.into_iter()
.chain(self.list_completions_models())
.chain(self.list_embeddings_models())
.chain(self.list_tensor_models())
.collect()
}
......@@ -88,6 +92,10 @@ impl ModelManager {
self.embeddings_engines.read().list()
}
pub fn list_tensor_models(&self) -> Vec<String> {
self.tensor_engines.read().list()
}
pub fn add_completions_model(
&self,
model: &str,
......@@ -115,6 +123,15 @@ impl ModelManager {
clients.add(model, engine)
}
pub fn add_tensor_model(
&self,
model: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.remove(model)
......@@ -130,6 +147,11 @@ impl ModelManager {
clients.remove(model)
}
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.remove(model)
}
pub fn get_embeddings_engine(
&self,
model: &str,
......@@ -163,6 +185,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_tensor_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.tensor_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd.
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
......
......@@ -33,6 +33,7 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
},
tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
},
};
......@@ -59,6 +60,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Chat,
ModelType::Completions,
ModelType::Embedding,
ModelType::TensorBased,
];
impl ModelWatcher {
......@@ -213,10 +215,12 @@ impl ModelWatcher {
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
let mut tensor_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
......@@ -228,20 +232,29 @@ impl ModelWatcher {
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
if !chat_model_removed
&& !completions_model_removed
&& !embeddings_model_removed
&& !tensor_model_removed
{
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}",
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed
embeddings_model_removed,
tensor_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding))
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased))
&& let Some(tx) = &self.model_update_tx
{
tx.send(ModelUpdate::Removed(*model_type)).await.ok();
......@@ -421,11 +434,24 @@ impl ModelWatcher {
self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?;
} else if model_entry.model_input == ModelInput::Tensor
&& model_entry.model_type.supports_tensor()
{
// Case 5: Tensor + Tensor (non-LLM)
let push_router = PushRouter::<
NvCreateTensorRequest,
Annotated<NvCreateTensorResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager.add_tensor_model(&model_entry.name, engine)?;
} else {
// Reject unsupported combinations
anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings",
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
model_entry.model_type,
model_entry.model_input.as_str()
);
......
......@@ -3,3 +3,4 @@
pub mod kserve;
pub mod openai;
pub mod tensor;
......@@ -11,10 +11,10 @@ use crate::http::service::Metrics;
use crate::http::service::metrics;
use crate::discovery::ModelManager;
use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
use dynamo_runtime::transports::etcd;
use futures::pin_mut;
use tokio::task::JoinHandle;
......@@ -22,6 +22,8 @@ use tokio_stream::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use crate::grpc::service::openai::completion_response_stream;
use crate::grpc::service::tensor::tensor_response_stream;
use std::convert::{TryFrom, TryInto};
use tonic::{Request, Response, Status, transport::Server};
use crate::protocols::openai::completions::{
......@@ -33,8 +35,8 @@ pub mod inference {
}
use inference::grpc_inference_service_server::{GrpcInferenceService, GrpcInferenceServiceServer};
use inference::{
InferParameter, ModelConfig, ModelConfigRequest, ModelConfigResponse, ModelInferRequest,
ModelInferResponse, ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
ModelConfig, ModelConfigRequest, ModelConfigResponse, ModelInferRequest, ModelInferResponse,
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};
/// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint
......@@ -79,6 +81,10 @@ impl State {
pub fn etcd_client(&self) -> Option<&etcd::Client> {
self.etcd_client.as_ref()
}
fn is_tensor_model(&self, model: &String) -> bool {
self.manager.list_tensor_models().contains(model)
}
}
#[derive(Clone)]
......@@ -180,8 +186,34 @@ impl GrpcInferenceService for KserveService {
&self,
request: Request<ModelInferRequest>,
) -> Result<Response<ModelInferResponse>, Status> {
let model = request.get_ref().model_name.clone();
let request = request.into_inner();
let request_id = request.id.clone();
// [gluo TODO] refactor to reuse code, inference logic is largely the same
if self.state().is_tensor_model(&model) {
// Fallback handling by assuming the model is OpenAI Completions model
let tensor_request: NvCreateTensorRequest = NvCreateTensorRequest::try_from(request)
.map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;
let stream = tensor_response_stream(self.state_clone(), tensor_request, false).await?;
let tensor_response = NvCreateTensorResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold completions stream: {:?}", e);
Status::internal(format!("Failed to fold completions stream: {}", e))
})?;
let mut reply: ModelInferResponse = tensor_response.try_into().map_err(|e| {
Status::invalid_argument(format!("Failed to parse response: {}", e))
})?;
reply.id = request_id;
return Ok(Response::new(reply));
}
// Fallback handling by assuming the model is OpenAI Completions model
let mut completion_request: NvCreateCompletionRequest = request
.try_into()
.map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;
......@@ -216,13 +248,12 @@ impl GrpcInferenceService for KserveService {
.await
.map_err(|e| {
tracing::error!("Failed to fold completions stream: {:?}", e);
Status::internal("Failed to fold completions stream")
Status::internal(format!("Failed to fold completions stream: {}", e))
})?;
let mut reply: ModelInferResponse = completion_response
.try_into()
.map_err(|e| Status::invalid_argument(format!("Failed to parse response: {}", e)))?;
reply.id = request_id;
Ok(Response::new(reply))
......@@ -244,9 +275,7 @@ impl GrpcInferenceService for KserveService {
// and passing AsyncEngineStream for each request to the response stream
// which will be collectively polling.
while let Some(request) = request_stream.next().await {
// Must keep track of 'request_id' which will be returned in corresponding response
let request_id: String;
let mut completion_request: NvCreateCompletionRequest = match request {
let request = match request {
Err(e) => {
tracing::error!("Unexpected gRPC failed to read request: {}", e);
yield ModelStreamInferResponse {
......@@ -256,13 +285,49 @@ impl GrpcInferenceService for KserveService {
continue;
}
Ok(request) => {
request_id = request.id.clone();
request.try_into().map_err(|e| {
Status::invalid_argument(format!("Failed to parse request: {}", e))
})?
request
}
};
let model = request.model_name.clone();
// [gluo TODO] refactor to reuse code, inference logic is largely the same
if state.is_tensor_model(&model) {
// Must keep track of 'request_id' which will be returned in corresponding response
let request_id = request.id.clone();
let tensor_request: NvCreateTensorRequest = request.try_into().map_err(|e| {
Status::invalid_argument(format!("Failed to parse request: {}", e))
})?;
let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;
pin_mut!(stream);
while let Some(response) = stream.next().await {
match response.data {
Some(data) => {
let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
Status::invalid_argument(format!("Failed to parse response: {}", e))
})?;
if reply.infer_response.is_some() {
reply.infer_response.as_mut().unwrap().id = request_id.clone();
}
yield reply;
},
None => {
// Skip if no data is present, the response is for annotation
},
}
}
continue;
}
// Fallback handling by assuming the model is OpenAI Completions model
// Must keep track of 'request_id' which will be returned in corresponding response
let request_id = request.id.clone();
let mut completion_request: NvCreateCompletionRequest = request.try_into().map_err(|e| {
Status::invalid_argument(format!("Failed to parse request: {}", e))
})?;
// Apply template values if present
if let Some(template) = &template {
if completion_request.inner.model.is_empty() {
......@@ -309,7 +374,7 @@ impl GrpcInferenceService for KserveService {
"Failed to fold completions stream: {:?}",
e
);
Status::internal("Failed to fold completions stream")
Status::internal(format!("Failed to fold completions stream: {}", e))
})?;
let mut response: ModelStreamInferResponse = completion_response.try_into().map_err(|e| {
......@@ -332,11 +397,49 @@ impl GrpcInferenceService for KserveService {
&self,
request: Request<ModelMetadataRequest>,
) -> Result<Response<ModelMetadataResponse>, Status> {
let models = self.state.manager().list_completions_models();
let entries = self.state.manager().get_model_entries();
let request_model_name = &request.into_inner().name;
if let Some(model_name) = models.into_iter().find(|n| request_model_name == n) {
if let Some(entry) = entries
.into_iter()
.find(|entry| request_model_name == &entry.name)
{
if entry.model_type.supports_tensor() {
if let Some(config) = entry.runtime_config.as_ref()
&& let Some(tensor_model_config) = config.tensor_model_config.as_ref()
{
return Ok(Response::new(ModelMetadataResponse {
name: model_name,
name: tensor_model_config.name.clone(),
versions: vec!["1".to_string()],
platform: "dynamo".to_string(),
inputs: tensor_model_config
.inputs
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: input.data_type.to_string(),
shape: input.shape.clone(),
})
.collect(),
outputs: tensor_model_config
.outputs
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: output.data_type.to_string(),
shape: output.shape.clone(),
},
)
.collect(),
}));
}
Err(Status::invalid_argument(format!(
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if entry.model_type.supports_completions() {
return Ok(Response::new(ModelMetadataResponse {
name: entry.name,
versions: vec!["1".to_string()],
platform: "dynamo".to_string(),
inputs: vec![
......@@ -365,6 +468,7 @@ impl GrpcInferenceService for KserveService {
],
}));
}
}
Err(Status::not_found(format!(
"Model '{}' not found",
request_model_name
......@@ -375,11 +479,53 @@ impl GrpcInferenceService for KserveService {
&self,
request: Request<ModelConfigRequest>,
) -> Result<Response<ModelConfigResponse>, Status> {
let models = self.state.manager().list_completions_models();
let entries = self.state.manager().get_model_entries();
let request_model_name = &request.into_inner().name;
if let Some(model_name) = models.into_iter().find(|n| request_model_name == n) {
if let Some(entry) = entries
.into_iter()
.find(|entry| request_model_name == &entry.name)
{
if entry.model_type.supports_tensor() {
if let Some(config) = entry.runtime_config.as_ref()
&& let Some(tensor_model_config) = config.tensor_model_config.as_ref()
{
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
backend: "dynamo".to_string(),
input: tensor_model_config
.inputs
.iter()
.map(|input| ModelInput {
name: input.name.clone(),
data_type: input.data_type.to_kserve(),
dims: input.shape.clone(),
..Default::default()
})
.collect(),
output: tensor_model_config
.outputs
.iter()
.map(|output| ModelOutput {
name: output.name.clone(),
data_type: output.data_type.to_kserve(),
dims: output.shape.clone(),
..Default::default()
})
.collect(),
..Default::default()
};
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config.clone()),
}));
}
Err(Status::invalid_argument(format!(
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if entry.model_type.supports_completions() {
let config = ModelConfig {
name: model_name,
name: entry.name,
platform: "dynamo".to_string(),
backend: "dynamo".to_string(),
input: vec![
......@@ -417,209 +563,10 @@ impl GrpcInferenceService for KserveService {
config: Some(config),
}));
}
}
Err(Status::not_found(format!(
"Model '{}' not found",
request_model_name
)))
}
}
impl TryFrom<ModelInferRequest> for NvCreateCompletionRequest {
type Error = Status;
fn try_from(request: ModelInferRequest) -> Result<Self, Self::Error> {
// Protocol requires if `raw_input_contents` is used to hold input data,
// it must be used for all inputs.
if !request.raw_input_contents.is_empty()
&& request.inputs.len() != request.raw_input_contents.len()
{
return Err(Status::invalid_argument(
"`raw_input_contents` must be used for all inputs",
));
}
// iterate through inputs
let mut text_input = None;
let mut stream = false;
for (idx, input) in request.inputs.iter().enumerate() {
match input.name.as_str() {
"text_input" => {
if input.datatype != "BYTES" {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to be of type BYTES for string input, got {:?}",
input.datatype
)));
}
if input.shape != vec![1] && input.shape != vec![1, 1] {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
let bytes = &content.bytes_contents[0];
text_input = Some(String::from_utf8_lossy(bytes).to_string());
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'text_input'")
})?;
if raw_input.len() < 4 {
return Err(Status::invalid_argument(
"'text_input' raw input must be length-prefixed (>= 4 bytes)",
));
}
// We restrict the 'text_input' only contain one element, only need to
// parse the first element. Skip first four bytes that is used to store
// the length of the input.
text_input = Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
}
}
}
"streaming" | "stream" => {
if input.datatype != "BOOL" {
return Err(Status::invalid_argument(format!(
"Expected '{}' to be of type BOOL, got {:?}",
input.name, input.datatype
)));
}
if input.shape != vec![1] {
return Err(Status::invalid_argument(format!(
"Expected 'stream' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
stream = content.bool_contents[0];
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'stream'")
})?;
if raw_input.is_empty() {
return Err(Status::invalid_argument(
"'stream' raw input must contain at least one byte",
));
}
stream = raw_input[0] != 0;
}
}
}
_ => {
return Err(Status::invalid_argument(format!(
"Invalid input name: {}, supported inputs are 'text_input', 'stream'",
input.name
)));
}
}
}
// return error if text_input is None
let text_input = match text_input {
Some(input) => input,
None => {
return Err(Status::invalid_argument(
"Missing required input: 'text_input'",
));
}
};
Ok(NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: request.model_name,
prompt: Prompt::String(text_input),
stream: Some(stream),
user: if request.id.is_empty() {
None
} else {
Some(request.id.clone())
},
..Default::default()
},
common: Default::default(),
nvext: None,
})
}
}
impl TryFrom<NvCreateCompletionResponse> for ModelInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
let mut outputs = vec![];
let mut text_output = vec![];
let mut finish_reason = vec![];
for choice in &response.inner.choices {
text_output.push(choice.text.clone());
if let Some(reason) = choice.finish_reason.as_ref() {
match reason {
CompletionFinishReason::Stop => {
finish_reason.push("stop".to_string());
}
CompletionFinishReason::Length => {
finish_reason.push("length".to_string());
}
CompletionFinishReason::ContentFilter => {
finish_reason.push("content_filter".to_string());
}
}
}
}
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "text_output".to_string(),
datatype: "BYTES".to_string(),
shape: vec![text_output.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: text_output
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "finish_reason".to_string(),
datatype: "BYTES".to_string(),
shape: vec![finish_reason.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: finish_reason
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
Ok(ModelInferResponse {
model_name: response.inner.model,
model_version: "1".to_string(),
id: response.inner.id,
outputs,
parameters: ::std::collections::HashMap::<String, InferParameter>::new(),
raw_output_contents: vec![],
})
}
}
impl TryFrom<NvCreateCompletionResponse> for ModelStreamInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
match ModelInferResponse::try_from(response) {
Ok(response) => Ok(ModelStreamInferResponse {
infer_response: Some(response),
..Default::default()
}),
Err(e) => Ok(ModelStreamInferResponse {
infer_response: None,
error_message: format!("Failed to convert response: {}", e),
}),
}
}
}
......@@ -15,12 +15,14 @@ use crate::protocols::openai::completions::{
use crate::types::Annotated;
use super::kserve;
use super::kserve::inference;
// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
};
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
use tonic::Status;
......@@ -185,3 +187,205 @@ fn get_or_create_request_id(primary: Option<&str>) -> String {
let uuid = uuid::Uuid::new_v4();
uuid.to_string()
}
impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
type Error = Status;
fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
// Protocol requires if `raw_input_contents` is used to hold input data,
// it must be used for all inputs.
if !request.raw_input_contents.is_empty()
&& request.inputs.len() != request.raw_input_contents.len()
{
return Err(Status::invalid_argument(
"`raw_input_contents` must be used for all inputs",
));
}
// iterate through inputs
let mut text_input = None;
let mut stream = false;
for (idx, input) in request.inputs.iter().enumerate() {
match input.name.as_str() {
"text_input" => {
if input.datatype != "BYTES" {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to be of type BYTES for string input, got {:?}",
input.datatype
)));
}
if input.shape != vec![1] && input.shape != vec![1, 1] {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
let bytes = content.bytes_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'text_input' must contain exactly one element",
)
})?;
text_input = Some(String::from_utf8_lossy(bytes).to_string());
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'text_input'")
})?;
if raw_input.len() < 4 {
return Err(Status::invalid_argument(
"'text_input' raw input must be length-prefixed (>= 4 bytes)",
));
}
// We restrict the 'text_input' only contain one element, only need to
// parse the first element. Skip first four bytes that is used to store
// the length of the input.
text_input = Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
}
}
}
"streaming" | "stream" => {
if input.datatype != "BOOL" {
return Err(Status::invalid_argument(format!(
"Expected '{}' to be of type BOOL, got {:?}",
input.name, input.datatype
)));
}
if input.shape != vec![1] {
return Err(Status::invalid_argument(format!(
"Expected 'stream' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
stream = *content.bool_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'stream' must contain exactly one element",
)
})?;
}
None => {
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'stream'")
})?;
if raw_input.is_empty() {
return Err(Status::invalid_argument(
"'stream' raw input must contain at least one byte",
));
}
stream = raw_input[0] != 0;
}
}
}
_ => {
return Err(Status::invalid_argument(format!(
"Invalid input name: {}, supported inputs are 'text_input', 'stream'",
input.name
)));
}
}
}
// return error if text_input is None
let text_input = match text_input {
Some(input) => input,
None => {
return Err(Status::invalid_argument(
"Missing required input: 'text_input'",
));
}
};
Ok(NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: request.model_name,
prompt: Prompt::String(text_input),
stream: Some(stream),
user: if request.id.is_empty() {
None
} else {
Some(request.id.clone())
},
..Default::default()
},
common: Default::default(),
nvext: None,
})
}
}
impl TryFrom<NvCreateCompletionResponse> for inference::ModelInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
let mut outputs = vec![];
let mut text_output = vec![];
let mut finish_reason = vec![];
for choice in &response.inner.choices {
text_output.push(choice.text.clone());
let reason_str = match choice.finish_reason.as_ref() {
Some(CompletionFinishReason::Stop) => "stop",
Some(CompletionFinishReason::Length) => "length",
Some(CompletionFinishReason::ContentFilter) => "content_filter",
None => "",
};
finish_reason.push(reason_str.to_string());
}
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "text_output".to_string(),
datatype: "BYTES".to_string(),
shape: vec![text_output.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: text_output
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "finish_reason".to_string(),
datatype: "BYTES".to_string(),
shape: vec![finish_reason.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: finish_reason
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
Ok(inference::ModelInferResponse {
model_name: response.inner.model,
model_version: "1".to_string(),
id: response.inner.id,
outputs,
parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
raw_output_contents: vec![],
})
}
}
impl TryFrom<NvCreateCompletionResponse> for inference::ModelStreamInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
match inference::ModelInferResponse::try_from(response) {
Ok(response) => Ok(inference::ModelStreamInferResponse {
infer_response: Some(response),
..Default::default()
}),
Err(e) => Ok(inference::ModelStreamInferResponse {
infer_response: None,
error_message: format!("Failed to convert response: {}", e),
}),
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{
engine::AsyncEngineContext,
pipeline::{AsyncEngineContextProvider, Context},
protocols::annotated::AnnotationsProvider,
};
use futures::{Stream, StreamExt, stream};
use std::str::FromStr;
use std::sync::Arc;
use crate::types::Annotated;
use super::kserve;
// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{
disconnect::{ConnectionHandle, create_connection_monitor},
metrics::{Endpoint, ResponseMetricCollector},
};
use crate::{http::service::metrics::InflightGuard, preprocessor::LLMMetricAnnotation};
use crate::protocols::tensor;
use crate::protocols::tensor::{
NvCreateTensorRequest, NvCreateTensorResponse, Tensor, TensorMetadata,
};
use crate::grpc::service::kserve::inference;
use crate::grpc::service::kserve::inference::DataType;
use tonic::Status;
/// Dynamo Annotation for the request ID
pub const ANNOTATION_REQUEST_ID: &str = "request_id";
/// Tensor Request Handler
///
/// This method will handle the incoming request for model type tensor. The endpoint is a "source"
/// for an [`super::OpenAICompletionsStreamingEngine`] and will return a stream of
/// responses which will be forward to the client.
///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler.
pub async fn tensor_response_stream(
state: Arc<kserve::State>,
request: NvCreateTensorRequest,
streaming: bool,
) -> Result<impl Stream<Item = Annotated<NvCreateTensorResponse>>, Status> {
// create the context for the request
let request_id = get_or_create_request_id(request.id.as_deref());
let request = Context::with_id(request, request_id.clone());
let context = request.context();
// [gluo TODO] revisit metrics to properly expose it
// create the connection handles
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let model = &request.model;
// todo - error handling should be more robust
let engine = state
.manager()
.get_tensor_engine(model)
.map_err(|_| Status::not_found("model not found"))?;
let inflight_guard =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Tensor, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(model);
// prepare to process any annotations
let annotations = request.annotations();
// issue the generate call on the engine
let stream = engine.generate(request).await.map_err(|e| {
Status::internal(format!("Failed to generate tensor response stream: {}", e))
})?;
// capture the context to cancel the stream if the client disconnects
let ctx = stream.context();
// prepare any requested annotations
let annotations = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateTensorResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
// apply any annotations to the front of the stream
let stream = stream::iter(annotations).chain(stream);
// Tap on the stream to collect response metrics
let stream = stream.inspect(move |response| {
process_metrics_only(response, &mut response_collector);
});
let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
// if we got here, then we will return a response and the potentially long running task has completed successfully
// without need to be cancelled.
connection_handle.disarm();
Ok(stream)
}
/// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation.
/// This is gRPC variant of `monitor_for_disconnects` as that implementation has SSE specific handling.
/// Should decouple and reuse `monitor_for_disconnects`
///
/// Uses `tokio::select!` to choose between receiving responses from the source stream or detecting when
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
/// naturally, we mark the request as successful and send the final `[DONE]` event.
pub fn grpc_monitor_for_disconnects<T>(
stream: impl Stream<Item = Annotated<T>>,
context: Arc<dyn AsyncEngineContext>,
mut inflight_guard: InflightGuard,
mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
stream_handle.arm();
async_stream::stream! {
tokio::pin!(stream);
loop {
tokio::select! {
event = stream.next() => {
match event {
Some(response) => {
yield response;
}
None => {
// Stream ended normally
inflight_guard.mark_ok();
stream_handle.disarm();
break;
}
}
}
// todo - test request cancellation with kserve frontend and tensor-based models
_ = context.stopped() => {
tracing::trace!("Context stopped; breaking stream");
break;
}
}
}
}
}
fn process_metrics_only<T>(
annotated: &Annotated<T>,
response_collector: &mut ResponseMetricCollector,
) {
// update metrics
if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
response_collector.observe_current_osl(metrics.output_tokens);
response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
}
}
/// Get the request ID from a primary source, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>) -> String {
// Try to get the request ID from the primary source
if let Some(primary) = primary
&& let Ok(uuid) = uuid::Uuid::parse_str(primary)
{
return uuid.to_string();
}
// Try to parse the request ID as a UUID, or generate a new one if missing/invalid
let uuid = uuid::Uuid::new_v4();
uuid.to_string()
}
impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest {
type Error = Status;
fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
// Protocol requires if `raw_input_contents` is used to hold input data,
// it must be used for all inputs.
if !request.raw_input_contents.is_empty()
&& request.inputs.len() != request.raw_input_contents.len()
{
return Err(Status::invalid_argument(
"`raw_input_contents` must be used for all inputs",
));
}
let mut tensor_request = NvCreateTensorRequest {
id: if !request.id.is_empty() {
Some(request.id.clone())
} else {
None
},
model: request.model_name.clone(),
tensors: Vec::new(),
nvext: None,
};
// iterate through inputs
for (idx, input) in request.inputs.into_iter().enumerate() {
let mut tensor = Tensor {
metadata: TensorMetadata {
name: input.name.clone(),
data_type: tensor::DataType::from_str(&input.datatype)
.map_err(|err| Status::invalid_argument(err.to_string()))?,
shape: input.shape.clone(),
},
// Placeholder, will be filled below
data: tensor::FlattenTensor::Bool(Vec::new()),
};
match &input.contents {
// If contents is provided in InferInputTensor
Some(contents) => {
tensor.set_data_from_tensor_contents(contents);
}
// If not in InferInputTensor, contents is provided in raw_input_contents
None => {
tensor.set_data_from_raw_contents(&request.raw_input_contents[idx])?;
}
}
tensor_request.tensors.push(tensor);
}
Ok(tensor_request)
}
}
impl tensor::Tensor {
fn set_data_from_tensor_contents(&mut self, contents: &inference::InferTensorContents) {
self.data = match self.metadata.data_type {
tensor::DataType::Bool => tensor::FlattenTensor::Bool(contents.bool_contents.clone()),
tensor::DataType::Uint8 => tensor::FlattenTensor::Uint8(
contents.uint_contents.iter().map(|&x| x as u8).collect(),
),
tensor::DataType::Uint16 => tensor::FlattenTensor::Uint16(
contents.uint_contents.iter().map(|&x| x as u16).collect(),
),
tensor::DataType::Uint32 => {
tensor::FlattenTensor::Uint32(contents.uint_contents.clone())
}
tensor::DataType::Uint64 => {
tensor::FlattenTensor::Uint64(contents.uint64_contents.clone())
}
tensor::DataType::Int8 => tensor::FlattenTensor::Int8(
contents.int_contents.iter().map(|&x| x as i8).collect(),
),
tensor::DataType::Int16 => tensor::FlattenTensor::Int16(
contents.int_contents.iter().map(|&x| x as i16).collect(),
),
tensor::DataType::Int32 => tensor::FlattenTensor::Int32(contents.int_contents.clone()),
tensor::DataType::Int64 => {
tensor::FlattenTensor::Int64(contents.int64_contents.clone())
}
tensor::DataType::Float32 => {
tensor::FlattenTensor::Float32(contents.fp32_contents.clone())
}
tensor::DataType::Float64 => {
tensor::FlattenTensor::Float64(contents.fp64_contents.clone())
}
tensor::DataType::Bytes => {
tensor::FlattenTensor::Bytes(contents.bytes_contents.clone())
}
}
}
#[allow(clippy::result_large_err)]
fn set_data_from_raw_contents(&mut self, raw_input: &[u8]) -> Result<(), Status> {
let element_count = self.metadata.shape.iter().try_fold(1usize, |acc, &d| {
if d < 0 {
Err(Status::invalid_argument(format!(
"Shape contains negative dimension: {}",
d
)))
} else {
acc.checked_mul(d as usize).ok_or_else(|| {
Status::invalid_argument("Overflow occurred while calculating element count")
})
}
})?;
let data_size = self.metadata.data_type.size();
// For BYTES type, we need to parse length-prefixed strings and properly slice them
// into bytes of array, and early return
if data_size == 0 {
self.data = self.raw_input_to_bytes_tensor(element_count, raw_input)?;
return Ok(());
}
// Control reaches here on non-bytes types
// validate raw input length before conversion
if !raw_input.len().is_multiple_of(data_size) {
return Err(Status::invalid_argument(format!(
"Raw input length must be a multiple of {}",
data_size
)));
} else if raw_input.len() / data_size != element_count {
return Err(Status::invalid_argument(format!(
"Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
self.metadata.name,
element_count,
raw_input.len() / data_size
)));
}
self.data = self.raw_input_to_typed_tensor(raw_input)?;
Ok(())
}
#[allow(clippy::result_large_err)]
fn raw_input_to_bytes_tensor(
&self,
element_count: usize,
raw_input: &[u8],
) -> Result<tensor::FlattenTensor, Status> {
// element is not fixed size for bytes type, so the raw input has
// length-prefixed bytes for each element.
let mut bytes_contents = vec![];
let mut offset = 0;
while offset + 4 <= raw_input.len() {
let len =
u32::from_le_bytes(raw_input[offset..offset + 4].try_into().unwrap()) as usize;
offset += 4;
if offset + len > raw_input.len() {
return Err(Status::invalid_argument(format!(
"Invalid length-prefixed BYTES input for '{}', length exceeds raw input size",
self.metadata.name
)));
}
bytes_contents.push(raw_input[offset..offset + len].to_vec());
offset += len;
}
if offset != raw_input.len() {
return Err(Status::invalid_argument(format!(
"Invalid length-prefixed BYTES input for '{}', extra bytes at the end",
self.metadata.name
)));
}
if element_count != bytes_contents.len() {
return Err(Status::invalid_argument(format!(
"Raw input element count for '{}' does not match expected size, expected {} elements, got {} elements",
self.metadata.name,
element_count,
bytes_contents.len()
)));
}
Ok(tensor::FlattenTensor::Bytes(bytes_contents))
}
#[allow(clippy::result_large_err)]
fn raw_input_to_typed_tensor(&self, raw_input: &[u8]) -> Result<tensor::FlattenTensor, Status> {
// In Rust, we can not "reinterpret cast" a Vec<u8> to Vec<T> directly
// as Vec require the pointer to be aligned with the type T, which can not
// be guaranteed from Vec<u8>. We will have to reconstruct the Vec<T> element
// by element which results in data copy.
// Here we assume little endianess for all types as the KServe protocol doesn't
// specify the endianness while it should have.
match self.metadata.data_type {
tensor::DataType::Bool => Ok(tensor::FlattenTensor::Bool(
raw_input.iter().map(|&b| b != 0).collect(),
)),
tensor::DataType::Uint8 => Ok(tensor::FlattenTensor::Uint8(
raw_input.chunks_exact(1).map(|chunk| chunk[0]).collect(),
)),
tensor::DataType::Uint16 => Ok(tensor::FlattenTensor::Uint16(
raw_input
.chunks_exact(2)
.map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
.collect(),
)),
tensor::DataType::Uint32 => Ok(tensor::FlattenTensor::Uint32(
raw_input
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect(),
)),
tensor::DataType::Uint64 => Ok(tensor::FlattenTensor::Uint64(
raw_input
.chunks_exact(8)
.map(|chunk| {
u64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect(),
)),
tensor::DataType::Int8 => Ok(tensor::FlattenTensor::Int8(
raw_input
.chunks_exact(1)
.map(|chunk| chunk[0] as i8)
.collect(),
)),
tensor::DataType::Int16 => Ok(tensor::FlattenTensor::Int16(
raw_input
.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]))
.collect(),
)),
tensor::DataType::Int32 => Ok(tensor::FlattenTensor::Int32(
raw_input
.chunks_exact(4)
.map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect(),
)),
tensor::DataType::Int64 => Ok(tensor::FlattenTensor::Int64(
raw_input
.chunks_exact(8)
.map(|chunk| {
i64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect(),
)),
tensor::DataType::Float32 => Ok(tensor::FlattenTensor::Float32(
raw_input
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect(),
)),
tensor::DataType::Float64 => Ok(tensor::FlattenTensor::Float64(
raw_input
.chunks_exact(8)
.map(|chunk| {
f64::from_le_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
])
})
.collect(),
)),
tensor::DataType::Bytes => Err(Status::internal(format!(
"Unexpected BYTES type in non-bytes branch for input '{}'",
self.metadata.name
))),
}
}
}
impl TryFrom<NvCreateTensorResponse> for inference::ModelInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
let mut infer_response = inference::ModelInferResponse {
model_name: response.model,
model_version: "1".to_string(),
id: response.id.unwrap_or_default(),
outputs: vec![],
parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
raw_output_contents: vec![],
};
for tensor in &response.tensors {
infer_response
.outputs
.push(inference::model_infer_response::InferOutputTensor {
name: tensor.metadata.name.clone(),
datatype: tensor.metadata.data_type.to_string(),
shape: tensor.metadata.shape.clone(),
contents: match &tensor.data {
tensor::FlattenTensor::Bool(data) => Some(inference::InferTensorContents {
bool_contents: data.clone(),
..Default::default()
}),
tensor::FlattenTensor::Uint8(data) => {
Some(inference::InferTensorContents {
uint_contents: data.iter().map(|&x| x as u32).collect(),
..Default::default()
})
}
tensor::FlattenTensor::Uint16(data) => {
Some(inference::InferTensorContents {
uint_contents: data.iter().map(|&x| x as u32).collect(),
..Default::default()
})
}
tensor::FlattenTensor::Uint32(data) => {
Some(inference::InferTensorContents {
uint_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Uint64(data) => {
Some(inference::InferTensorContents {
uint64_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Int8(data) => Some(inference::InferTensorContents {
int_contents: data.iter().map(|&x| x as i32).collect(),
..Default::default()
}),
tensor::FlattenTensor::Int16(data) => {
Some(inference::InferTensorContents {
int_contents: data.iter().map(|&x| x as i32).collect(),
..Default::default()
})
}
tensor::FlattenTensor::Int32(data) => {
Some(inference::InferTensorContents {
int_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Int64(data) => {
Some(inference::InferTensorContents {
int64_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Float32(data) => {
Some(inference::InferTensorContents {
fp32_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Float64(data) => {
Some(inference::InferTensorContents {
fp64_contents: data.clone(),
..Default::default()
})
}
tensor::FlattenTensor::Bytes(data) => {
Some(inference::InferTensorContents {
bytes_contents: data.clone(),
..Default::default()
})
}
},
..Default::default()
});
}
Ok(infer_response)
}
}
impl TryFrom<NvCreateTensorResponse> for inference::ModelStreamInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateTensorResponse) -> Result<Self, Self::Error> {
match inference::ModelInferResponse::try_from(response) {
Ok(response) => Ok(inference::ModelStreamInferResponse {
infer_response: Some(response),
..Default::default()
}),
Err(e) => Ok(inference::ModelStreamInferResponse {
infer_response: None,
error_message: format!("Failed to convert response: {}", e),
}),
}
}
}
impl tensor::DataType {
pub fn to_kserve(&self) -> i32 {
match *self {
tensor::DataType::Bool => DataType::TypeBool as i32,
tensor::DataType::Uint8 => DataType::TypeUint8 as i32,
tensor::DataType::Uint16 => DataType::TypeUint16 as i32,
tensor::DataType::Uint32 => DataType::TypeUint32 as i32,
tensor::DataType::Uint64 => DataType::TypeUint64 as i32,
tensor::DataType::Int8 => DataType::TypeInt8 as i32,
tensor::DataType::Int16 => DataType::TypeInt16 as i32,
tensor::DataType::Int32 => DataType::TypeInt32 as i32,
tensor::DataType::Int64 => DataType::TypeInt64 as i32,
tensor::DataType::Float32 => DataType::TypeFp32 as i32,
tensor::DataType::Float64 => DataType::TypeFp64 as i32,
tensor::DataType::Bytes => DataType::TypeString as i32,
}
}
}
impl std::fmt::Display for tensor::DataType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
tensor::DataType::Bool => write!(f, "BOOL"),
tensor::DataType::Uint8 => write!(f, "UINT8"),
tensor::DataType::Uint16 => write!(f, "UINT16"),
tensor::DataType::Uint32 => write!(f, "UINT32"),
tensor::DataType::Uint64 => write!(f, "UINT64"),
tensor::DataType::Int8 => write!(f, "INT8"),
tensor::DataType::Int16 => write!(f, "INT16"),
tensor::DataType::Int32 => write!(f, "INT32"),
tensor::DataType::Int64 => write!(f, "INT64"),
tensor::DataType::Float32 => write!(f, "FP32"),
tensor::DataType::Float64 => write!(f, "FP64"),
tensor::DataType::Bytes => write!(f, "BYTES"),
}
}
}
impl FromStr for tensor::DataType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"BOOL" => Ok(tensor::DataType::Bool),
"UINT8" => Ok(tensor::DataType::Uint8),
"UINT16" => Ok(tensor::DataType::Uint16),
"UINT32" => Ok(tensor::DataType::Uint32),
"UINT64" => Ok(tensor::DataType::Uint64),
"INT8" => Ok(tensor::DataType::Int8),
"INT16" => Ok(tensor::DataType::Int16),
"INT32" => Ok(tensor::DataType::Int32),
"INT64" => Ok(tensor::DataType::Int64),
"FP32" => Ok(tensor::DataType::Float32),
"FP64" => Ok(tensor::DataType::Float64),
"BYTES" => Ok(tensor::DataType::Bytes),
_ => Err(anyhow::anyhow!("Invalid data type")),
}
}
}
......@@ -73,6 +73,9 @@ pub enum Endpoint {
/// OAI Responses
Responses,
/// Tensor
Tensor,
}
/// Metrics for the HTTP service
......@@ -456,6 +459,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Responses => write!(f, "responses"),
Endpoint::Tensor => write!(f, "tensor"),
}
}
}
......@@ -467,6 +471,7 @@ impl Endpoint {
Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings",
Endpoint::Responses => "responses",
Endpoint::Tensor => "tensor",
}
}
}
......
......@@ -5,6 +5,8 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor;
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>,
......@@ -20,6 +22,16 @@ pub struct ModelRuntimeConfig {
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
// Provide tensor model config in the case where the model type is Tensor.
// Currently use JSON object for convinence, the programmatic way is to
// define the model config struct as part of the tensor protocol and
// import it here.
// [gluo TODO] switch to ModelConfig if desired and workout a way to
// prepare it in a convinent way, the protobuf library used by tonic
// doesn't provide JSON parsing.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tensor_model_config: Option<tensor::TensorModelConfig>,
}
impl ModelRuntimeConfig {
......
......@@ -15,6 +15,7 @@ bitflags! {
/// - `ModelType::Chat`
/// - `ModelType::Completions`
/// - `ModelType::Embedding`
/// - `ModelType::TensorBased`
///
/// For example, a model that supports both chat and completions can be
/// expressed as:
......@@ -34,6 +35,7 @@ bitflags! {
const Chat = 1 << 0;
const Completions = 1 << 1;
const Embedding = 1 << 2;
const TensorBased = 1 << 3;
}
}
......@@ -51,6 +53,9 @@ impl ModelType {
pub fn supports_embedding(&self) -> bool {
self.contains(ModelType::Embedding)
}
pub fn supports_tensor(&self) -> bool {
self.contains(ModelType::TensorBased)
}
pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new();
......@@ -63,6 +68,9 @@ impl ModelType {
if self.supports_embedding() {
result.push("embedding");
}
if self.supports_tensor() {
result.push("tensor");
}
result
}
......@@ -79,6 +87,9 @@ impl ModelType {
if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
}
// [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
// current use of endpoint type is LLM specific and so does the HTTP
// server that uses it.
endpoint_types
}
}
......@@ -95,6 +106,8 @@ pub enum ModelInput {
Text,
/// Pre-processed input
Tokens,
/// Tensor input
Tensor,
}
impl ModelInput {
......@@ -102,6 +115,7 @@ impl ModelInput {
match self {
Self::Text => "text",
Self::Tokens => "tokens",
Self::Tensor => "tensor",
}
}
}
......@@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize};
pub mod codec;
pub mod common;
pub mod openai;
pub mod tensor;
/// The token ID type
pub type TokenIdType = u32;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::protocols::Annotated;
use anyhow::Result;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use validator::Validate;
// [gluo TODO] whether it makes sense to have aggregator for tensor..
// we could if considering aggregation to be stacking the tensors by adding
// one more dimension. i.e. stream of [2, 2] tensors to be aggregated to
// [-1, 2, 2]. Will decide it later and currently do not allow aggregation.
// mod aggregator;
// pub use aggregator::DeltaAggregator;
// [gluo TODO] nvext is LLM specific, we really only use the annotation field
pub use super::openai::nvext::{NvExt, NvExtProvider};
#[derive(Debug, Serialize, Clone, Eq, PartialEq, Deserialize)]
pub enum DataType {
Bool,
Uint8,
Uint16,
Uint32,
Uint64,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
Bytes,
}
impl DataType {
pub fn size(&self) -> usize {
match self {
DataType::Bool => size_of::<bool>(),
DataType::Uint8 => size_of::<u8>(),
DataType::Uint16 => size_of::<u16>(),
DataType::Uint32 => size_of::<u32>(),
DataType::Uint64 => size_of::<u64>(),
DataType::Int8 => size_of::<i8>(),
DataType::Int16 => size_of::<i16>(),
DataType::Int32 => size_of::<i32>(),
DataType::Int64 => size_of::<i64>(),
DataType::Float32 => size_of::<f32>(),
DataType::Float64 => size_of::<f64>(),
DataType::Bytes => 0, // variable length, return 0 as indicator
}
}
}
#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
// Self-describing encoding removes ambiguity between signed/unsigned and width variants.
#[serde(tag = "data_type", content = "values")]
pub enum FlattenTensor {
Bool(Vec<bool>),
// [gluo NOTE] f16, and bf16 is not stably supported
Uint8(Vec<u8>),
Uint16(Vec<u16>),
Uint32(Vec<u32>),
Uint64(Vec<u64>),
Int8(Vec<i8>),
Int16(Vec<i16>),
Int32(Vec<i32>),
Int64(Vec<i64>),
Float32(Vec<f32>),
Float64(Vec<f64>),
// Typically use to store string data, but really it can store
// arbitrary data such as serialized handles for custom worker behavior.
Bytes(Vec<Vec<u8>>),
}
#[allow(clippy::len_without_is_empty)]
impl FlattenTensor {
pub fn len(&self) -> usize {
match self {
Self::Bool(v) => v.len(),
Self::Uint8(v) => v.len(),
Self::Uint16(v) => v.len(),
Self::Uint32(v) => v.len(),
Self::Uint64(v) => v.len(),
Self::Int8(v) => v.len(),
Self::Int16(v) => v.len(),
Self::Int32(v) => v.len(),
Self::Int64(v) => v.len(),
Self::Float32(v) => v.len(),
Self::Float64(v) => v.len(),
Self::Bytes(v) => v.len(),
}
}
pub fn data_type(&self) -> DataType {
match self {
Self::Bool(_) => DataType::Bool,
Self::Uint8(_) => DataType::Uint8,
Self::Uint16(_) => DataType::Uint16,
Self::Uint32(_) => DataType::Uint32,
Self::Uint64(_) => DataType::Uint64,
Self::Int8(_) => DataType::Int8,
Self::Int16(_) => DataType::Int16,
Self::Int32(_) => DataType::Int32,
Self::Int64(_) => DataType::Int64,
Self::Float32(_) => DataType::Float32,
Self::Float64(_) => DataType::Float64,
Self::Bytes(_) => DataType::Bytes,
}
}
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorMetadata {
pub name: String,
pub data_type: DataType,
pub shape: Vec<i64>,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tensor {
pub metadata: TensorMetadata,
pub data: FlattenTensor,
}
impl validator::Validate for Tensor {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
use validator::{ValidationError, ValidationErrors};
let mut errs = ValidationErrors::new();
// dtype must match
if self.metadata.data_type != self.data.data_type() {
let mut e = ValidationError::new("dtype_mismatch");
e.message = Some("metadata.data_type does not match data variant".into());
errs.add("data_type", e);
}
let mut product: usize = 1;
for &d in &self.metadata.shape {
if d < 0 {
let mut e = ValidationError::new("negative_dim");
e.message = Some("only -1 is allowed as a wildcard dimension".into());
errs.add("shape", e);
break;
}
product = product.saturating_mul(d as usize);
}
// bytes payloads may be variable-length per item; enforce outer count only
let expect_count = self.data.len();
if product != expect_count {
let mut e = ValidationError::new("element_count_mismatch");
e.message = Some(
format!(
"shape implies {} elements but data has {}",
product, expect_count
)
.into(),
);
errs.add("shape", e);
}
if errs.is_empty() { Ok(()) } else { Err(errs) }
}
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorRequest {
/// ID of the request
pub id: Option<String>,
/// ID of the model to use.
pub model: String,
/// Input tensors.
pub tensors: Vec<Tensor>,
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateTensorResponse {
/// ID of the corresponding request.
pub id: Option<String>,
/// ID of the model.
pub model: String,
/// Output tensors.
pub tensors: Vec<Tensor>,
}
/// Implements `NvExtProvider` for `NvCreateTensorRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateTensorRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
fn raw_prompt(&self) -> Option<String> {
// Not really apply here.
None
}
}
/// Implements `AnnotationsProvider` for `NvCreateTensorRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateTensorRequest {
/// Retrieves the list of annotations from `NvExt`, if present.
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.clone())
}
/// Checks whether a specific annotation exists in the request.
///
/// # Arguments
/// * `annotation` - A string slice representing the annotation to check.
///
/// # Returns
/// `true` if the annotation exists, `false` otherwise.
fn has_annotation(&self, annotation: &str) -> bool {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.as_ref())
.map(|annotations| annotations.contains(&annotation.to_string()))
.unwrap_or(false)
}
}
pub struct DeltaAggregator {
response: Option<NvCreateTensorResponse>,
error: Option<String>,
}
impl NvCreateTensorResponse {
pub async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateTensorResponse>>,
) -> Result<NvCreateTensorResponse> {
let aggregator = stream
.fold(
DeltaAggregator {
response: None,
error: None,
},
|mut aggregator, delta| async move {
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
if aggregator.error.is_none() {
aggregator.error = Some(error);
}
return aggregator;
}
};
match delta.data {
Some(resp) => {
if aggregator.response.is_none() {
aggregator.response = Some(resp);
} else if aggregator.error.is_none() {
aggregator.error =
Some("Multiple responses in non-streaming mode".to_string());
}
}
None => {
// Ignore metadata-only deltas in non-streaming mode.
}
}
aggregator
},
)
.await;
if let Some(error) = aggregator.error {
Err(anyhow::anyhow!(error))
} else if let Some(response) = aggregator.response {
Ok(response)
} else {
Err(anyhow::anyhow!("No response received"))
}
}
}
......@@ -60,3 +60,21 @@ pub mod openai {
ServerStreamingEngine<NvCreateEmbeddingRequest, Annotated<NvCreateEmbeddingResponse>>;
}
}
pub mod generic {
use super::*;
use dynamo_runtime::pipeline::{ServerStreamingEngine, UnaryEngine};
pub mod tensor {
use super::*;
pub use protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
/// A [`UnaryEngine`] implementation for the generic Tensor API
pub type TensorUnaryEngine = UnaryEngine<NvCreateTensorRequest, NvCreateTensorResponse>;
/// A [`ServerStreamingEngine`] implementation for the generic Tensor API
pub type TensorStreamingEngine =
ServerStreamingEngine<NvCreateTensorRequest, Annotated<NvCreateTensorResponse>>;
}
}
......@@ -212,6 +212,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
Endpoint::ChatCompletions => 1,
Endpoint::Embeddings => todo!(),
Endpoint::Responses => todo!(),
Endpoint::Tensor => todo!(),
};
let request_type = match request_type {
......
......@@ -6,6 +6,11 @@ pub mod kserve_test {
pub mod inference {
tonic::include_proto!("inference");
}
use dynamo_llm::discovery::ModelEntry;
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_llm::model_type::{ModelInput, ModelType};
use dynamo_llm::protocols::tensor;
use dynamo_runtime::protocols::EndpointId;
use inference::grpc_inference_service_client::GrpcInferenceServiceClient;
use inference::{
DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest,
......@@ -22,6 +27,7 @@ pub mod kserve_test {
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
};
use dynamo_runtime::{
CancellationToken,
......@@ -181,6 +187,52 @@ pub mod kserve_test {
}
}
struct TensorEngine {}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateTensorRequest>,
ManyOut<Annotated<NvCreateTensorResponse>>,
Error,
> for TensorEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateTensorRequest>,
) -> Result<ManyOut<Annotated<NvCreateTensorResponse>>, Error> {
// Echo input tensor in response, additionally check if there is input tensor
// name "repeat", if so, send the same response as many time as the value of the tensor
let (request, context) = request.transfer(());
let ctx = context.context();
let repeat_count = request
.tensors
.iter()
.find_map(|t| {
if t.metadata.name == "repeat"
&& let tensor::FlattenTensor::Int32(data) = &t.data
&& !data.is_empty()
{
return Some(data[0]);
}
None
})
.unwrap_or(1);
let stream = async_stream::stream! {
for _ in 0..repeat_count {
yield Annotated::from_data(NvCreateTensorResponse {
id: request.id.clone(),
model: request.model.clone(),
tensors: request.tensors.clone(),
});
}
};
Ok(ResponseStream::new(Box::pin(stream), ctx))
}
}
/// Wait for the HTTP service to be ready by checking its health endpoint
async fn get_ready_client(port: u16, timeout_secs: u64) -> GrpcInferenceServiceClient<Channel> {
let start = tokio::time::Instant::now();
......@@ -232,15 +284,58 @@ pub mod kserve_test {
manager
.add_completions_model("split", split.clone())
.unwrap();
manager.save_model_entry(
"split",
ModelEntry {
name: "split".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "split".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager
.add_chat_completions_model("failure", failure.clone())
.unwrap();
manager
.add_completions_model("failure", failure.clone())
.unwrap();
manager.save_model_entry(
"failure",
ModelEntry {
name: "failure".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "failure".to_string(),
},
model_type: ModelType::Completions | ModelType::Chat,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager
.add_completions_model("long_running", long_running.clone())
.unwrap();
manager.save_model_entry(
"long_running",
ModelEntry {
name: "long_running".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "long_running".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
(service, split, failure, long_running)
}
......@@ -276,6 +371,7 @@ pub mod kserve_test {
InferCancellation = 8992,
StreamInferCancellation = 8993,
ModelInfo = 8994,
TensorModel = 8995,
}
#[rstest]
......@@ -474,8 +570,8 @@ pub mod kserve_test {
);
assert_eq!(
output.shape,
vec![0],
"Expected 'finish_reason' to have shape [0]"
vec![1],
"Expected 'finish_reason' to have shape [1]"
);
}
_ => panic!("Unexpected output name: {}", output.name),
......@@ -632,8 +728,8 @@ pub mod kserve_test {
);
assert_eq!(
output.shape,
vec![0],
"Expected 'finish_reason' to have shape [0]"
vec![1],
"Expected 'finish_reason' to have shape [1]"
);
}
_ => panic!("Unexpected output name: {}", output.name),
......@@ -724,8 +820,8 @@ pub mod kserve_test {
);
assert_eq!(
output.shape,
vec![0],
"Expected 'finish_reason' to have shape [0]"
vec![1],
"Expected 'finish_reason' to have shape [1]"
);
}
_ => panic!("Unexpected output name: {}", output.name),
......@@ -1052,4 +1148,326 @@ pub mod kserve_test {
}
}
}
#[rstest]
#[tokio::test]
async fn test_tensor_infer(
#[with(TestPort::TensorModel as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
text_input: inference::model_infer_request::InferInputTensor,
) {
// add tensor model
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", tensor.clone())
.unwrap();
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());
let mut client = get_ready_client(TestPort::TensorModel as u16, 5).await;
let request = tonic::Request::new(ModelMetadataRequest {
name: "tensor".into(),
version: "".into(),
});
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let entry = ModelEntry {
name: "tensor".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "endpoint".to_string(),
},
model_type: ModelType::TensorBased,
model_input: ModelInput::Tensor,
runtime_config: None,
};
service_with_engines
.0
.model_manager()
.save_model_entry("key", entry);
let response = client.model_metadata(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::InvalidArgument,
"Expected InvalidArgument error for unregistered model, get {}",
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message()
);
let request = tonic::Request::new(ModelConfigRequest {
name: "tensor".into(),
version: "".into(),
});
let response = client.model_config(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::InvalidArgument,
"Expected InvalidArgument error for unregistered model, get {}",
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message()
);
// Change model entry to have model config
service_with_engines
.0
.model_manager()
.remove_model_entry("key");
let entry = ModelEntry {
name: "tensor".to_string(),
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "endpoint".to_string(),
},
model_type: ModelType::TensorBased,
model_input: ModelInput::Tensor,
runtime_config: Some(ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
name: "tensor".to_string(),
inputs: vec![tensor::TensorMetadata {
name: "input".to_string(),
data_type: tensor::DataType::Bytes,
shape: vec![1],
}],
outputs: vec![tensor::TensorMetadata {
name: "output".to_string(),
data_type: tensor::DataType::Bool,
shape: vec![-1],
}],
}),
..Default::default()
}),
};
service_with_engines
.0
.model_manager()
.save_model_entry("key", entry);
// Success
let request = tonic::Request::new(ModelMetadataRequest {
name: "tensor".into(),
version: "".into(),
});
let response = client.model_metadata(request).await.unwrap();
assert_eq!(
response.get_ref().name,
"tensor",
"Expected response of the same model name",
);
// input
for io in &response.get_ref().inputs {
match io.name.as_str() {
"input" => {
assert_eq!(
io.datatype, "BYTES",
"Expected 'input' to have datatype 'BYTES'"
);
assert_eq!(io.shape, vec![1], "Expected 'input' to have shape [1]");
}
_ => panic!("Unexpected output name: {}", io.name),
}
}
// output
for io in &response.get_ref().outputs {
match io.name.as_str() {
"output" => {
assert_eq!(
io.datatype, "BOOL",
"Expected 'output' to have datatype 'BOOL'"
);
assert_eq!(io.shape, vec![-1], "Expected 'output' to have shape [-1]");
}
_ => panic!("Unexpected output name: {}", io.name),
}
}
let model_name = "tensor";
let inputs = vec![text_input.clone()];
let request = tonic::Request::new(ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: inputs.clone(),
..Default::default()
});
let response = client.model_infer(request).await.unwrap();
validate_tensor_response(response, model_name, inputs);
// streaming response in model_infer(), expect failure
let repeat = inference::model_infer_request::InferInputTensor {
name: "repeat".into(),
datatype: "INT32".into(),
shape: vec![1],
contents: Some(inference::InferTensorContents {
int_contents: vec![2],
..Default::default()
}),
..Default::default()
};
let inputs = vec![text_input.clone(), repeat.clone()];
let request = tonic::Request::new(ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: inputs.clone(),
..Default::default()
});
let response = client.model_infer(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::Internal,
"Expected Internal error for trying to stream response in ModelInfer, get {}",
err
);
// assert "stream" in error message
assert!(
err.message()
.contains("Multiple responses in non-streaming mode"),
"Expected error message to contain 'Multiple responses in non-streaming mode', got: {}",
err.message()
);
// model_stream_infer() and raw_input_contents
{
let inputs = vec![text_input.clone(), repeat.clone()];
let outbound = async_stream::stream! {
let request_count = 1;
for _ in 0..request_count {
let mut text_input = text_input.clone();
text_input.contents = None; // Clear contents to use raw_input_contents
let text_input_str = "dummy input";
let input_len = text_input_str.len() as u32;
let mut serialized_text_input = input_len.to_le_bytes().to_vec();
serialized_text_input.extend_from_slice(text_input_str.as_bytes());
let mut repeat = repeat.clone();
repeat.contents = None; // Clear contents to use raw_input_contents
let serialized_repeat = 2i32.to_le_bytes().to_vec();
let request = ModelInferRequest {
model_name: model_name.into(),
model_version: "1".into(),
id: "1234".into(),
inputs: vec![text_input.clone(), repeat.clone()],
raw_input_contents: vec![serialized_text_input, serialized_repeat],
..Default::default()
};
yield request;
}
};
let response = client
.model_stream_infer(Request::new(outbound))
.await
.unwrap();
let mut inbound = response.into_inner();
let mut response_idx = 0;
while let Some(response) = inbound.message().await.unwrap() {
assert!(
response.error_message.is_empty(),
"Expected successful inference"
);
assert!(
response.infer_response.is_some(),
"Expected successful inference"
);
if let Some(response) = &response.infer_response {
validate_tensor_response(
Response::new(response.clone()),
model_name,
inputs.clone(),
);
}
response_idx += 1;
}
assert_eq!(response_idx, 2, "Expected 2 responses")
}
}
fn validate_tensor_response(
response: Response<ModelInferResponse>,
model_name: &str,
inputs: Vec<inference::model_infer_request::InferInputTensor>,
) {
assert_eq!(
response.get_ref().model_name,
model_name,
"Expected response of the same model name",
);
assert_eq!(
response.get_ref().model_version,
"1",
"Expected response of the same model version"
);
assert_eq!(
response.get_ref().id,
"1234",
"Expected response of the same request ID"
);
assert_eq!(
response.get_ref().outputs.len(),
inputs.len(),
"Expected the same number of outputs as inputs",
);
for output in &response.get_ref().outputs {
let mut found = false;
for input in &inputs {
if input.name != output.name {
continue;
}
assert_eq!(
output.name, input.name,
"Expected output name to be '{}', got '{}'",
input.name, output.name
);
assert_eq!(
output.datatype, input.datatype,
"Expected output datatype to be '{}', got '{}'",
input.datatype, output.datatype
);
assert_eq!(
output.shape, input.shape,
"Expected output shape to be '{:?}', got '{:?}'",
input.shape, output.shape
);
found = true;
break;
}
if !found {
panic!("Unexpected output name: {}", output.name);
}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment