Unverified Commit bb8eaa23 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix: KServe propagate error to client in stream infer (#5263)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent ace35a8e
...@@ -361,7 +361,17 @@ impl GrpcInferenceService for KserveService { ...@@ -361,7 +361,17 @@ impl GrpcInferenceService for KserveService {
let stream = tensor_response_stream(state.clone(), tensor_request, true).await?; let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;
pin_mut!(stream); pin_mut!(stream);
while let Some(response) = stream.next().await { while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data { match response.data {
Some(data) => { Some(data) => {
let data = ExtendedNvCreateTensorResponse {response: data, let data = ExtendedNvCreateTensorResponse {response: data,
...@@ -412,7 +422,17 @@ impl GrpcInferenceService for KserveService { ...@@ -412,7 +422,17 @@ impl GrpcInferenceService for KserveService {
if streaming { if streaming {
pin_mut!(stream); pin_mut!(stream);
while let Some(response) = stream.next().await { while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data { match response.data {
Some(data) => { Some(data) => {
let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| { let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
......
...@@ -14,6 +14,8 @@ use crate::types::Annotated; ...@@ -14,6 +14,8 @@ use crate::types::Annotated;
use super::kserve; use super::kserve;
use validator::Validate;
// [gluo NOTE] These are common utilities that should be shared between frontends // [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::metrics::InflightGuard; use crate::http::service::metrics::InflightGuard;
use crate::http::service::{ use crate::http::service::{
...@@ -304,6 +306,9 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest { ...@@ -304,6 +306,9 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest {
} }
tensor_request.tensors.push(tensor); tensor_request.tensors.push(tensor);
} }
if let Err(validation_error) = tensor_request.validate() {
return Err(Status::invalid_argument(validation_error.to_string()));
}
Ok(tensor_request) Ok(tensor_request)
} }
} }
...@@ -530,6 +535,9 @@ impl TryFrom<ExtendedNvCreateTensorResponse> for inference::ModelInferResponse { ...@@ -530,6 +535,9 @@ impl TryFrom<ExtendedNvCreateTensorResponse> for inference::ModelInferResponse {
fn try_from(extended_response: ExtendedNvCreateTensorResponse) -> Result<Self, Self::Error> { fn try_from(extended_response: ExtendedNvCreateTensorResponse) -> Result<Self, Self::Error> {
let response = extended_response.response; let response = extended_response.response;
if let Err(e) = response.validate() {
return Err(anyhow::anyhow!("Invalid NvCreateTensorResponse: {}", e));
}
// Convert response-level parameters // Convert response-level parameters
let parameters = convert_dynamo_to_kserve_params(&response.parameters); let parameters = convert_dynamo_to_kserve_params(&response.parameters);
......
...@@ -114,6 +114,7 @@ impl FlattenTensor { ...@@ -114,6 +114,7 @@ impl FlattenTensor {
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct TensorMetadata { pub struct TensorMetadata {
pub name: String, pub name: String,
pub data_type: DataType, pub data_type: DataType,
...@@ -125,6 +126,7 @@ pub struct TensorMetadata { ...@@ -125,6 +126,7 @@ pub struct TensorMetadata {
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)] #[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)]
#[serde(deny_unknown_fields)]
pub struct TensorModelConfig { pub struct TensorModelConfig {
pub name: String, pub name: String,
pub inputs: Vec<TensorMetadata>, pub inputs: Vec<TensorMetadata>,
...@@ -136,6 +138,7 @@ pub struct TensorModelConfig { ...@@ -136,6 +138,7 @@ pub struct TensorModelConfig {
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct Tensor { pub struct Tensor {
pub metadata: TensorMetadata, pub metadata: TensorMetadata,
pub data: FlattenTensor, pub data: FlattenTensor,
...@@ -182,6 +185,7 @@ impl validator::Validate for Tensor { ...@@ -182,6 +185,7 @@ impl validator::Validate for Tensor {
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorRequest { pub struct NvCreateTensorRequest {
/// ID of the request /// ID of the request
pub id: Option<String>, pub id: Option<String>,
...@@ -190,6 +194,7 @@ pub struct NvCreateTensorRequest { ...@@ -190,6 +194,7 @@ pub struct NvCreateTensorRequest {
pub model: String, pub model: String,
/// Input tensors. /// Input tensors.
#[validate(nested)]
pub tensors: Vec<Tensor>, pub tensors: Vec<Tensor>,
/// Optional request-level parameters /// Optional request-level parameters
...@@ -203,6 +208,7 @@ pub struct NvCreateTensorRequest { ...@@ -203,6 +208,7 @@ pub struct NvCreateTensorRequest {
/// A response structure for unary chat completion responses, embedding OpenAI's /// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse`. /// `CreateChatCompletionResponse`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorResponse { pub struct NvCreateTensorResponse {
/// ID of the corresponding request. /// ID of the corresponding request.
pub id: Option<String>, pub id: Option<String>,
...@@ -211,6 +217,7 @@ pub struct NvCreateTensorResponse { ...@@ -211,6 +217,7 @@ pub struct NvCreateTensorResponse {
pub model: String, pub model: String,
/// Output tensors. /// Output tensors.
#[validate(nested)]
pub tensors: Vec<Tensor>, pub tensors: Vec<Tensor>,
/// Optional response-level parameters /// Optional response-level parameters
......
...@@ -295,7 +295,7 @@ pub mod kserve_test { ...@@ -295,7 +295,7 @@ pub mod kserve_test {
inference::model_infer_request::InferInputTensor { inference::model_infer_request::InferInputTensor {
name: "int_input".into(), name: "int_input".into(),
datatype: "UINT32".into(), datatype: "UINT32".into(),
shape: vec![1], shape: vec![3],
contents: Some(inference::InferTensorContents { contents: Some(inference::InferTensorContents {
uint_contents: input, uint_contents: input,
..Default::default() ..Default::default()
...@@ -1172,7 +1172,7 @@ pub mod kserve_test { ...@@ -1172,7 +1172,7 @@ pub mod kserve_test {
inputs: vec![tensor::TensorMetadata { inputs: vec![tensor::TensorMetadata {
name: "input".to_string(), name: "input".to_string(),
data_type: tensor::DataType::Int32, data_type: tensor::DataType::Int32,
shape: vec![1], shape: vec![3],
parameters: Default::default(), parameters: Default::default(),
}], }],
outputs: vec![tensor::TensorMetadata { outputs: vec![tensor::TensorMetadata {
......
...@@ -75,6 +75,25 @@ async def generate(request, context): ...@@ -75,6 +75,25 @@ async def generate(request, context):
params = {} params = {}
if "parameters" in request: if "parameters" in request:
params.update(request["parameters"]) params.update(request["parameters"])
if "malformed_response" in request["parameters"]:
request["tensors"][0]["data"] = {"values": [0, 1, 2]}
yield {
"model": request["model"],
"tensors": request["tensors"],
"parameters": params,
}
return
elif "data_mismatch" in request["parameters"]:
# Empty the tensor values so the data no longer matches the declared shape
request["tensors"][0]["data"]["values"] = []
yield {
"model": request["model"],
"tensors": request["tensors"],
"parameters": params,
}
return
elif "raise_exception" in request["parameters"]:
raise ValueError("Intentional exception raised by echo_tensor_worker.")
params["processed"] = {"bool": True} params["processed"] = {"bool": True}
......
...@@ -14,10 +14,14 @@ from __future__ import annotations ...@@ -14,10 +14,14 @@ from __future__ import annotations
import logging import logging
import os import os
import queue
import shutil import shutil
from functools import partial
import numpy as np
import pytest import pytest
import triton_echo_client import triton_echo_client
import tritonclient.grpc as grpcclient
from tests.utils.constants import QWEN from tests.utils.constants import QWEN
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -105,4 +109,111 @@ def test_echo(start_services_with_echo_worker) -> None: ...@@ -105,4 +109,111 @@ def test_echo(start_services_with_echo_worker) -> None:
client = triton_echo_client.TritonEchoClient(grpc_port=frontend_port) client = triton_echo_client.TritonEchoClient(grpc_port=frontend_port)
client.check_health() client.check_health()
client.run_infer() client.run_infer()
client.run_stream_infer()
client.get_config() client.get_config()
@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0  # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
    "request_params",
    [
        {"malformed_response": True},
        {"raise_exception": True},
    ],
    ids=["malformed_response", "raise_exception"],
)
def test_model_infer_failure(start_services_with_echo_worker, request_params):
    """Test that worker-side failures surface as errors from unary gRPC infer.

    Each parametrized request parameter is forwarded to the "echo" model and
    is expected to make the inference fail. The client call must raise, and
    the error text must identify which failure occurred.
    """
    frontend_port = start_services_with_echo_worker
    client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")
    try:
        input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
        inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
        inputs[0].set_data_from_numpy(input_data)

        # expect exception during inference
        with pytest.raises(Exception) as excinfo:
            client.infer("echo", inputs=inputs, parameters=request_params)
    finally:
        # Release the gRPC channel even when the expectation above fails.
        client.close()

    if "malformed_response" in request_params:
        assert "missing field `data_type`" in str(excinfo.value).lower()
    elif "raise_exception" in request_params:
        assert "intentional exception" in str(excinfo.value).lower()
@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0  # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
    "request_params",
    [
        {"malformed_response": True},
        {"raise_exception": True},
        {"data_mismatch": True},
    ],
    ids=["malformed_response", "raise_exception", "data_mismatch"],
)
def test_model_stream_infer_failure(start_services_with_echo_worker, request_params):
    """Test that failures are propagated to the client in gRPC stream infer.

    Each parametrized request parameter is forwarded to the "echo" model and
    is expected to make the inference fail. For stream infer the error is not
    raised by the client call; it is delivered to the stream callback, so the
    test inspects the first item the callback receives.
    """
    frontend_port = start_services_with_echo_worker
    client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")

    input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
    inputs[0].set_data_from_numpy(input_data)

    class UserData:
        def __init__(self):
            self._completed_requests: queue.Queue[
                grpcclient.InferResult | Exception
            ] = queue.Queue()

    # Define the callback function. Note the last two parameters should be
    # result and error. InferenceServerClient would provide the results of an
    # inference as grpcclient.InferResult in result. For successful
    # inference, error will be None, otherwise it will be an object of
    # tritonclientutils.InferenceServerException holding the error details
    def callback(user_data, result, error):
        print("Received callback")
        if error:
            user_data._completed_requests.put(error)
        else:
            user_data._completed_requests.put(result)

    user_data = UserData()
    try:
        client.start_stream(
            callback=partial(callback, user_data),
        )
        client.async_stream_infer(
            model_name="echo",
            inputs=inputs,
            parameters=request_params,
        )
        # A timeout raises queue.Empty here and fails the test directly,
        # instead of being mistaken for the expected inference error.
        data_item = user_data._completed_requests.get(timeout=5)
    finally:
        # Stop the stream and release the channel regardless of the outcome.
        client.stop_stream()
        client.close()

    # For stream infer, the exception and error will pass to the callback but
    # not raised
    assert isinstance(
        data_item, Exception
    ), f"Expected stream inference to fail, got: {data_item}"
    message = str(data_item).lower()

    if "malformed_response" in request_params:
        assert "missing field `data_type`" in message
    elif "data_mismatch" in request_params:
        assert "shape implies" in message
    elif "raise_exception" in request_params:
        assert "intentional exception" in message
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import queue
from functools import partial
import numpy as np import numpy as np
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
...@@ -63,9 +66,78 @@ class TritonEchoClient: ...@@ -63,9 +66,78 @@ class TritonEchoClient:
assert np.array_equal(input0_data, output0_data) assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data) assert np.array_equal(input1_data, output1_data)
def run_stream_infer(self) -> None:
    """Send one streaming inference request to the "echo" model and verify it.

    Builds an INT32 tensor and a BYTES tensor of 16 elements each, sends them
    over a gRPC stream, waits up to 5 seconds for the first callback item and
    asserts that the returned tensors equal the inputs.

    Raises:
        AssertionError: if the stream reports an error or the outputs differ.
        queue.Empty: if no response arrives within the timeout.
    """
    triton_client = self._client()
    model_name = "echo"

    inputs = [
        grpcclient.InferInput("INPUT0", [16], "INT32"),
        grpcclient.InferInput("INPUT1", [16], "BYTES"),
    ]

    input0_data = np.arange(start=0, stop=16, dtype=np.int32).reshape([16])
    input1_data = np.array(
        [str(x).encode("utf-8") for x in input0_data.reshape(input0_data.size)],
        dtype=np.object_,
    ).reshape([16])
    inputs[0].set_data_from_numpy(input0_data)
    inputs[1].set_data_from_numpy(input1_data)

    class UserData:
        def __init__(self):
            self._completed_requests = queue.Queue()

    # Define the callback function. Note the last two parameters should be
    # result and error. InferenceServerClient would provide the results of an
    # inference as grpcclient.InferResult in result. For successful
    # inference, error will be None, otherwise it will be an object of
    # tritonclientutils.InferenceServerException holding the error details
    def callback(user_data, result, error):
        print("Received callback")
        if error:
            user_data._completed_requests.put(error)
        else:
            user_data._completed_requests.put(result)

    user_data = UserData()
    try:
        triton_client.start_stream(
            callback=partial(callback, user_data),
        )
        triton_client.async_stream_infer(
            model_name=model_name,
            inputs=inputs,
        )
        data_item = user_data._completed_requests.get(timeout=5)
    finally:
        # Always stop the stream so it is not left open on failure or timeout.
        triton_client.stop_stream()

    assert (
        isinstance(data_item, Exception) is False
    ), f"Stream inference failed: {data_item}"

    output0_data = data_item.as_numpy("INPUT0")
    output1_data = data_item.as_numpy("INPUT1")
    assert (
        output0_data is not None
    ), "Expected response to include output tensor 'INPUT0'"
    assert (
        output1_data is not None
    ), "Expected response to include output tensor 'INPUT1'"
    assert np.array_equal(input0_data, output0_data)
    assert np.array_equal(input1_data, output1_data)
def get_config(self) -> None: def get_config(self) -> None:
triton_client = self._client() triton_client = self._client()
model_name = "echo" model_name = "echo"
response = triton_client.get_model_config(model_name=model_name) response = triton_client.get_model_config(model_name=model_name)
# Check one of the field that can only be set by providing Triton model config # Check one of the field that can only be set by providing Triton model config
assert response.config.model_transaction_policy.decoupled assert response.config.model_transaction_policy.decoupled
# Manual entry point: runs the same checks as the e2e echo test against a
# frontend that is already listening on gRPC port 8000.
if __name__ == "__main__":
    client = TritonEchoClient(grpc_port=8000)
    client.check_health()
    client.run_infer()
    client.run_stream_infer()
    client.get_config()
    print("Triton echo client ran successfully.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment