Unverified Commit eb3a486d authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: allow Triton model config specification in TensorModelConfig (#3874)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 1da9d70a
......@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
fn build_protos() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("src/grpc/protos/kserve.proto")?;
tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
.compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
Ok(())
}
......
......@@ -11,6 +11,8 @@ use crate::http::service::Metrics;
use crate::http::service::metrics;
use crate::discovery::ModelManager;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::protocols::tensor::TensorModelConfig;
use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate;
use anyhow::Result;
......@@ -38,6 +40,8 @@ use inference::{
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};
use prost::Message;
/// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint
/// for it as part of HTTP service. Should we always start HTTP service up
/// for non-inference?
......@@ -157,6 +161,27 @@ impl KserveServiceConfigBuilder {
}
}
#[allow(clippy::large_enum_variant)]
enum Config {
Dynamo(TensorModelConfig),
Triton(ModelConfig),
}
impl Config {
fn from_runtime_config(runtime_config: &ModelRuntimeConfig) -> Result<Config, anyhow::Error> {
if let Some(tensor_model_config) = runtime_config.tensor_model_config.as_ref() {
if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
let model_config = ModelConfig::decode(triton_model_config.as_slice())?;
Ok(Config::Triton(model_config))
} else {
Ok(Config::Dynamo(tensor_model_config.clone()))
}
} else {
Err(anyhow::anyhow!("no model config is provided"))
}
}
}
#[tonic::async_trait]
impl GrpcInferenceService for KserveService {
async fn model_infer(
......@@ -390,38 +415,76 @@ impl GrpcInferenceService for KserveService {
.find(|card| request_model_name == &card.display_name)
{
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(),
versions: vec!["1".to_string()],
platform: "dynamo".to_string(),
inputs: tensor_model_config
.inputs
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: input.data_type.to_string(),
shape: input.shape.clone(),
})
.collect(),
outputs: tensor_model_config
.outputs
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: output.data_type.to_string(),
shape: output.shape.clone(),
},
)
.collect(),
}));
let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| {
Status::invalid_argument(format!(
"Model '{}' has type Tensor but: {}",
request_model_name, e
))
})?;
match config {
Config::Triton(model_config) => {
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name,
versions: vec!["1".to_string()],
platform: model_config.platform,
inputs: model_config
.input
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: match inference::DataType::try_from(input.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: input.dims.clone(),
})
.collect(),
outputs: model_config
.output
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: match inference::DataType::try_from(
output.data_type,
) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: output.dims.clone(),
},
)
.collect(),
}));
}
Config::Dynamo(model_config) => {
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name.clone(),
versions: vec!["1".to_string()],
platform: "dynamo".to_string(),
inputs: model_config
.inputs
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: input.data_type.to_string(),
shape: input.shape.clone(),
})
.collect(),
outputs: model_config
.outputs
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: output.data_type.to_string(),
shape: output.shape.clone(),
},
)
.collect(),
}));
}
}
Err(Status::invalid_argument(format!(
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if card.model_type.supports_completions() {
return Ok(Response::new(ModelMetadataResponse {
name: card.display_name,
......@@ -471,42 +534,50 @@ impl GrpcInferenceService for KserveService {
.find(|card| request_model_name == &card.display_name)
{
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
backend: "dynamo".to_string(),
input: tensor_model_config
.inputs
.iter()
.map(|input| ModelInput {
name: input.name.clone(),
data_type: input.data_type.to_kserve(),
dims: input.shape.clone(),
..Default::default()
})
.collect(),
output: tensor_model_config
.outputs
.iter()
.map(|output| ModelOutput {
name: output.name.clone(),
data_type: output.data_type.to_kserve(),
dims: output.shape.clone(),
..Default::default()
})
.collect(),
..Default::default()
};
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config.clone()),
}));
let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| {
Status::invalid_argument(format!(
"Model '{}' has type Tensor but: {}",
request_model_name, e
))
})?;
match config {
Config::Triton(model_config) => {
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config),
}));
}
Config::Dynamo(tensor_model_config) => {
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
backend: "dynamo".to_string(),
input: tensor_model_config
.inputs
.iter()
.map(|input| ModelInput {
name: input.name.clone(),
data_type: input.data_type.to_kserve(),
dims: input.shape.clone(),
..Default::default()
})
.collect(),
output: tensor_model_config
.outputs
.iter()
.map(|output| ModelOutput {
name: output.name.clone(),
data_type: output.data_type.to_kserve(),
dims: output.shape.clone(),
..Default::default()
})
.collect(),
..Default::default()
};
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config.clone()),
}));
}
}
Err(Status::invalid_argument(format!(
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if card.model_type.supports_completions() {
let config = ModelConfig {
name: card.display_name,
......
......@@ -124,11 +124,15 @@ pub struct TensorMetadata {
pub parameters: Parameters,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)]
pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
// Optional Triton model config in serialized protobuf string,
// if provided, it supersedes the basic model config defined above.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<Vec<u8>>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
......
......@@ -42,6 +42,7 @@ pub mod kserve_test {
use tonic::{Request, Response, transport::Channel};
use dynamo_async_openai::types::Prompt;
use prost::Message;
struct SplitEngine {}
......@@ -361,6 +362,7 @@ pub mod kserve_test {
ModelInfo = 8994,
TensorModel = 8995,
TensorModelTypes = 8996,
TritonModelConfig = 8997,
}
#[rstest]
......@@ -1173,6 +1175,7 @@ pub mod kserve_test {
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
......@@ -1206,6 +1209,193 @@ pub mod kserve_test {
);
}
#[rstest]
#[tokio::test]
async fn test_triton_model_config(
#[with(TestPort::TritonModelConfig as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
) {
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());
let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await;
let model_name = "tensor";
let expected_model_config = inference::ModelConfig {
name: model_name.to_string(),
platform: "custom".to_string(),
backend: "custom".to_string(),
input: vec![
inference::ModelInput {
name: "input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: false,
..Default::default()
},
inference::ModelInput {
name: "optional_input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: true,
..Default::default()
},
],
output: vec![inference::ModelOutput {
name: "output".to_string(),
data_type: DataType::TypeBool as i32,
dims: vec![-1],
..Default::default()
}],
model_transaction_policy: Some(inference::ModelTransactionPolicy { decoupled: true }),
..Default::default()
};
let mut buf = vec![];
expected_model_config.encode(&mut buf).unwrap();
// Register a tensor model
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
triton_model_config: Some(buf.clone()),
..Default::default()
}),
..Default::default()
};
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
// success config
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});
let response = client
.model_config(request)
.await
.unwrap()
.into_inner()
.config;
let Some(config) = response else {
panic!("Expected Some(config), got None");
};
assert_eq!(
config, expected_model_config,
"Expected same model config to be returned",
);
// Pass config with both TensorModelConfig and triton_model_config,
// check if the Triton model config is used.
let _ = service_with_engines
.0
.model_manager()
.remove_model_card("key");
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
name: "tensor".to_string(),
inputs: vec![tensor::TensorMetadata {
name: "input".to_string(),
data_type: tensor::DataType::Int32,
shape: vec![1],
parameters: Default::default(),
}],
outputs: vec![tensor::TensorMetadata {
name: "output".to_string(),
data_type: tensor::DataType::Bool,
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: Some(buf.clone()),
}),
..Default::default()
};
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});
let response = client
.model_config(request)
.await
.unwrap()
.into_inner()
.config;
let Some(config) = response else {
panic!("Expected Some(config), got None");
};
assert_eq!(
config, expected_model_config,
"Expected same model config to be returned",
);
// Test invalid triton model config
let _ = service_with_engines
.0
.model_manager()
.remove_model_card("key");
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
triton_model_config: Some(vec![1, 2, 3, 4, 5]),
..Default::default()
}),
..Default::default()
};
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
// success config
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});
let response = client.model_config(request).await;
assert!(response.is_err());
let err = response.unwrap_err();
assert_eq!(
err.code(),
tonic::Code::InvalidArgument,
"Expected InvalidArgument error, get {}",
err
);
assert!(
err.message().contains("failed to decode Protobuf message"),
"Expected error message to contain 'failed to decode Protobuf message', got: {}",
err.message()
);
}
#[rstest]
#[tokio::test]
async fn test_tensor_infer(
......@@ -1255,9 +1445,8 @@ pub mod kserve_test {
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message().contains("no model config is provided"),
"Expected error message to contain 'no model config is provided', got: {}",
err.message()
);
......@@ -1276,9 +1465,8 @@ pub mod kserve_test {
err
);
assert!(
err.message()
.contains("has type Tensor but no model config is provided"),
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message().contains("no model config is provided"),
"Expected error message to contain 'no model config is provided', got: {}",
err.message()
);
......@@ -1305,6 +1493,7 @@ pub mod kserve_test {
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
......
......@@ -4,6 +4,9 @@
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.
# Knowing the test will be run in environment that has tritonclient installed,
# which contain the generated file equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
import uvloop
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
......@@ -17,17 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime):
endpoint = component.endpoint("generate")
triton_model_config = mc.ModelConfig()
triton_model_config.name = "echo"
triton_model_config.platform = "custom"
input_tensor = triton_model_config.input.add()
input_tensor.name = "input"
input_tensor.data_type = mc.TYPE_STRING
input_tensor.dims.extend([-1])
optional_input_tensor = triton_model_config.input.add()
optional_input_tensor.name = "optional_input"
optional_input_tensor.data_type = mc.TYPE_INT32
optional_input_tensor.dims.extend([-1])
optional_input_tensor.optional = True
output_tensor = triton_model_config.output.add()
output_tensor.name = "dummy_output"
output_tensor.data_type = mc.TYPE_STRING
output_tensor.dims.extend([-1])
triton_model_config.model_transaction_policy.decoupled = True
model_config = {
"name": "echo",
"inputs": [
{"name": "dummy_input", "data_type": "Bytes", "shape": [-1]},
],
"outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}],
"name": "",
"inputs": [],
"outputs": [],
"triton_model_config": triton_model_config.SerializeToString(),
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)
assert model_config == runtime_config.get_tensor_model_config()
# Internally the bytes string will be converted to List of int
retrieved_model_config = runtime_config.get_tensor_model_config()
retrieved_model_config["triton_model_config"] = bytes(
retrieved_model_config["triton_model_config"]
)
assert model_config == retrieved_model_config
# [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide
......@@ -46,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime):
async def generate(request, context):
"""Echo tensors and parameters back to the client."""
# [NOTE] gluo: currently there is no frontend side
# validation between model config and actual request,
# so any request will reach here and be echoed back.
print(f"Echoing request: {request}")
params = {}
......
......@@ -120,3 +120,4 @@ def start_services(request, runtime_services):
@pytest.mark.model(TEST_MODEL)
def test_echo() -> None:
triton_echo_client.run_infer()
triton_echo_client.get_config()
......@@ -43,3 +43,17 @@ def run_infer():
assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)
def get_config():
server_url = "localhost:8000"
try:
triton_client = grpcclient.InferenceServerClient(url=server_url)
except Exception as e:
print("channel creation failed: " + str(e))
sys.exit()
model_name = "echo"
response = triton_client.get_model_config(model_name=model_name)
# Check one of the field that can only be set by providing Triton model config
assert response.config.model_transaction_policy.decoupled
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment