Unverified Commit bb8eaa23 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix: KServe propagate error to client in stream infer (#5263)


Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
parent ace35a8e
......@@ -361,7 +361,17 @@ impl GrpcInferenceService for KserveService {
let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;
pin_mut!(stream);
while let Some(response) = stream.next().await {
while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data {
Some(data) => {
let data = ExtendedNvCreateTensorResponse {response: data,
......@@ -412,7 +422,17 @@ impl GrpcInferenceService for KserveService {
if streaming {
pin_mut!(stream);
while let Some(response) = stream.next().await {
while let Some(delta) = stream.next().await {
let response = match delta.ok() {
Err(e) => {
yield ModelStreamInferResponse {
error_message: e.to_string(),
infer_response: None
};
continue;
}
Ok(response) => response,
};
match response.data {
Some(data) => {
let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
......
......@@ -14,6 +14,8 @@ use crate::types::Annotated;
use super::kserve;
use validator::Validate;
// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::metrics::InflightGuard;
use crate::http::service::{
......@@ -304,6 +306,9 @@ impl TryFrom<inference::ModelInferRequest> for NvCreateTensorRequest {
}
tensor_request.tensors.push(tensor);
}
if let Err(validation_error) = tensor_request.validate() {
return Err(Status::invalid_argument(validation_error.to_string()));
}
Ok(tensor_request)
}
}
......@@ -530,6 +535,9 @@ impl TryFrom<ExtendedNvCreateTensorResponse> for inference::ModelInferResponse {
fn try_from(extended_response: ExtendedNvCreateTensorResponse) -> Result<Self, Self::Error> {
let response = extended_response.response;
if let Err(e) = response.validate() {
return Err(anyhow::anyhow!("Invalid NvCreateTensorResponse: {}", e));
}
// Convert response-level parameters
let parameters = convert_dynamo_to_kserve_params(&response.parameters);
......
......@@ -114,6 +114,7 @@ impl FlattenTensor {
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct TensorMetadata {
pub name: String,
pub data_type: DataType,
......@@ -125,6 +126,7 @@ pub struct TensorMetadata {
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)]
#[serde(deny_unknown_fields)]
pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
......@@ -136,6 +138,7 @@ pub struct TensorModelConfig {
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct Tensor {
pub metadata: TensorMetadata,
pub data: FlattenTensor,
......@@ -182,6 +185,7 @@ impl validator::Validate for Tensor {
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorRequest {
/// ID of the request
pub id: Option<String>,
......@@ -190,6 +194,7 @@ pub struct NvCreateTensorRequest {
pub model: String,
/// Input tensors.
#[validate(nested)]
pub tensors: Vec<Tensor>,
/// Optional request-level parameters
......@@ -203,6 +208,7 @@ pub struct NvCreateTensorRequest {
/// A response structure for tensor inference, carrying the output tensors and
/// optional response-level parameters for the corresponding request.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct NvCreateTensorResponse {
/// ID of the corresponding request.
pub id: Option<String>,
......@@ -211,6 +217,7 @@ pub struct NvCreateTensorResponse {
pub model: String,
/// Output tensors.
#[validate(nested)]
pub tensors: Vec<Tensor>,
/// Optional response-level parameters
......
......@@ -295,7 +295,7 @@ pub mod kserve_test {
inference::model_infer_request::InferInputTensor {
name: "int_input".into(),
datatype: "UINT32".into(),
shape: vec![1],
shape: vec![3],
contents: Some(inference::InferTensorContents {
uint_contents: input,
..Default::default()
......@@ -1172,7 +1172,7 @@ pub mod kserve_test {
inputs: vec![tensor::TensorMetadata {
name: "input".to_string(),
data_type: tensor::DataType::Int32,
shape: vec![1],
shape: vec![3],
parameters: Default::default(),
}],
outputs: vec![tensor::TensorMetadata {
......
......@@ -75,6 +75,25 @@ async def generate(request, context):
params = {}
if "parameters" in request:
params.update(request["parameters"])
if "malformed_response" in request["parameters"]:
request["tensors"][0]["data"] = {"values": [0, 1, 2]}
yield {
"model": request["model"],
"tensors": request["tensors"],
"parameters": params,
}
return
elif "data_mismatch" in request["parameters"]:
# Clear the tensor values so they no longer match the tensor shape, triggering a data mismatch error
request["tensors"][0]["data"]["values"] = []
yield {
"model": request["model"],
"tensors": request["tensors"],
"parameters": params,
}
return
elif "raise_exception" in request["parameters"]:
raise ValueError("Intentional exception raised by echo_tensor_worker.")
params["processed"] = {"bool": True}
......
......@@ -14,10 +14,14 @@ from __future__ import annotations
import logging
import os
import queue
import shutil
from functools import partial
import numpy as np
import pytest
import triton_echo_client
import tritonclient.grpc as grpcclient
from tests.utils.constants import QWEN
from tests.utils.managed_process import ManagedProcess
......@@ -105,4 +109,111 @@ def test_echo(start_services_with_echo_worker) -> None:
client = triton_echo_client.TritonEchoClient(grpc_port=frontend_port)
client.check_health()
client.run_infer()
client.run_stream_infer()
client.get_config()
@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0  # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
    "request_params",
    [
        {"malformed_response": True},
        {"raise_exception": True},
    ],
    ids=["malformed_response", "raise_exception"],
)
def test_model_infer_failure(start_services_with_echo_worker, request_params):
    """Test that a failing request surfaces as an error from unary gRPC infer.

    Each case sends a request-level parameter (``malformed_response`` or
    ``raise_exception``) to the "echo" model. The ``infer`` call must raise,
    and the text of the raised error must contain the fragment expected for
    that case.

    Args:
        start_services_with_echo_worker: Fixture value used as the frontend
            gRPC port.
        request_params: Request-level parameters selecting the failure mode.
    """
    frontend_port = start_services_with_echo_worker

    input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
    inputs[0].set_data_from_numpy(input_data)

    client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")
    try:
        # expect exception during inference
        with pytest.raises(Exception) as excinfo:
            client.infer("echo", inputs=inputs, parameters=request_params)
    finally:
        # Release the client whether or not the expected error showed up.
        client.close()

    error_text = str(excinfo.value).lower()
    if "malformed_response" in request_params:
        assert "missing field `data_type`" in error_text
    elif "raise_exception" in request_params:
        assert "intentional exception" in error_text
@pytest.mark.e2e
@pytest.mark.pre_merge
@pytest.mark.gpu_0  # Echo tensor worker is CPU-only (no GPU required)
@pytest.mark.parallel
@pytest.mark.parametrize(
    "request_params",
    [
        {"malformed_response": True},
        {"raise_exception": True},
        {"data_mismatch": True},
    ],
    ids=["malformed_response", "raise_exception", "data_mismatch"],
)
def test_model_stream_infer_failure(start_services_with_echo_worker, request_params):
    """Test that a failing request surfaces as an error in streaming gRPC infer.

    Each case sends a request-level parameter (``malformed_response``,
    ``raise_exception`` or ``data_mismatch``) to the "echo" model over a
    stream. Stream errors are not raised by the client call; they are handed
    to the stream callback, so the test collects the first callback item and
    requires it to be an exception whose text contains the fragment expected
    for that case.

    Args:
        start_services_with_echo_worker: Fixture value used as the frontend
            gRPC port.
        request_params: Request-level parameters selecting the failure mode.
    """
    frontend_port = start_services_with_echo_worker

    input_data = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    inputs = [grpcclient.InferInput("INPUT", input_data.shape, "FP32")]
    inputs[0].set_data_from_numpy(input_data)

    class UserData:
        def __init__(self):
            self._completed_requests: queue.Queue[
                grpcclient.InferResult | Exception
            ] = queue.Queue()

    # Define the callback function. Note the last two parameters should be
    # result and error. InferenceServerClient would provide the results of an
    # inference as grpcclient.InferResult in result. For successful
    # inference, error will be None, otherwise it will be an object of
    # tritonclientutils.InferenceServerException holding the error details
    def callback(user_data, result, error):
        print("Received callback")
        if error:
            user_data._completed_requests.put(error)
        else:
            user_data._completed_requests.put(result)

    user_data = UserData()
    client = grpcclient.InferenceServerClient(f"localhost:{frontend_port}")
    try:
        client.start_stream(
            callback=partial(callback, user_data),
        )
        client.async_stream_infer(
            model_name="echo",
            inputs=inputs,
            parameters=request_params,
        )
        # For stream infer, the exception and error will pass to the callback
        # but not raised. A timeout is reported separately so that it is not
        # mistaken for a propagated error with the wrong message.
        try:
            data_item = user_data._completed_requests.get(timeout=5)
        except queue.Empty:
            pytest.fail("Timed out waiting for the stream infer callback")
    finally:
        # Stop the stream and release the client on every exit path.
        client.stop_stream()
        client.close()

    assert isinstance(
        data_item, Exception
    ), f"Expected an error from stream infer, got: {data_item!r}"

    error_text = str(data_item).lower()
    if "malformed_response" in request_params:
        assert "missing field `data_type`" in error_text
    elif "data_mismatch" in request_params:
        assert "shape implies" in error_text
    elif "raise_exception" in request_params:
        assert "intentional exception" in error_text
......@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
import queue
from functools import partial
import numpy as np
import tritonclient.grpc as grpcclient
......@@ -63,9 +66,78 @@ class TritonEchoClient:
assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)
def run_stream_infer(self) -> None:
    """Send one streaming inference request to the "echo" model and verify the echo.

    Builds a 16-element INT32 input ("INPUT0") and a BYTES input ("INPUT1")
    holding the same numbers as UTF-8 strings, sends them over a gRPC stream,
    waits up to 5 seconds for the first callback item, and asserts that both
    tensors come back unchanged.

    Raises:
        AssertionError: If the callback delivered an error, or an output
            tensor is missing or differs from its input.
        queue.Empty: If no callback item arrives within the timeout.
    """
    triton_client = self._client()
    model_name = "echo"

    inputs = [
        grpcclient.InferInput("INPUT0", [16], "INT32"),
        grpcclient.InferInput("INPUT1", [16], "BYTES"),
    ]
    input0_data = np.arange(start=0, stop=16, dtype=np.int32).reshape([16])
    input1_data = np.array(
        [str(x).encode("utf-8") for x in input0_data.reshape(input0_data.size)],
        dtype=np.object_,
    ).reshape([16])
    inputs[0].set_data_from_numpy(input0_data)
    inputs[1].set_data_from_numpy(input1_data)

    class UserData:
        def __init__(self):
            self._completed_requests = queue.Queue()

    # Define the callback function. Note the last two parameters should be
    # result and error. InferenceServerClient would provide the results of an
    # inference as grpcclient.InferResult in result. For successful
    # inference, error will be None, otherwise it will be an object of
    # tritonclientutils.InferenceServerException holding the error details
    def callback(user_data, result, error):
        print("Received callback")
        if error:
            user_data._completed_requests.put(error)
        else:
            user_data._completed_requests.put(result)

    user_data = UserData()
    triton_client.start_stream(
        callback=partial(callback, user_data),
    )
    try:
        triton_client.async_stream_infer(
            model_name=model_name,
            inputs=inputs,
        )
        data_item = user_data._completed_requests.get(timeout=5)
    finally:
        # Always stop the stream opened above, even if sending or waiting fails.
        triton_client.stop_stream()

    assert not isinstance(
        data_item, Exception
    ), f"Stream inference failed: {data_item}"

    output0_data = data_item.as_numpy("INPUT0")
    output1_data = data_item.as_numpy("INPUT1")
    assert (
        output0_data is not None
    ), "Expected response to include output tensor 'INPUT0'"
    assert (
        output1_data is not None
    ), "Expected response to include output tensor 'INPUT1'"
    assert np.array_equal(input0_data, output0_data)
    assert np.array_equal(input1_data, output1_data)
def get_config(self) -> None:
    """Fetch the "echo" model config and assert its transaction policy is decoupled."""
    config_response = self._client().get_model_config(model_name="echo")
    # `decoupled` is one of the fields that can only be set by providing a
    # Triton model config.
    assert config_response.config.model_transaction_policy.decoupled
if __name__ == "__main__":
    # Manual run against a frontend serving gRPC on port 8000. The sequence
    # matches the checks test_echo performs, including the streaming request.
    client = TritonEchoClient(grpc_port=8000)
    client.check_health()
    client.run_infer()
    client.run_stream_infer()
    client.get_config()
    print("Triton echo client ran successfully.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment