Unverified Commit eb3a486d authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: allow Triton model config specification in TensorModelConfig (#3874)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 1da9d70a
...@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
fn build_protos() -> Result<(), Box<dyn std::error::Error>> { fn build_protos() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("src/grpc/protos/kserve.proto")?; tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
.compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
Ok(()) Ok(())
} }
......
...@@ -11,6 +11,8 @@ use crate::http::service::Metrics; ...@@ -11,6 +11,8 @@ use crate::http::service::Metrics;
use crate::http::service::metrics; use crate::http::service::metrics;
use crate::discovery::ModelManager; use crate::discovery::ModelManager;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::protocols::tensor::TensorModelConfig;
use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse}; use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
...@@ -38,6 +40,8 @@ use inference::{ ...@@ -38,6 +40,8 @@ use inference::{
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse, ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
}; };
use prost::Message;
/// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint /// [gluo TODO] 'metrics' are for HTTP service and there is HTTP endpoint
/// for it as part of HTTP service. Should we always start HTTP service up /// for it as part of HTTP service. Should we always start HTTP service up
/// for non-inference? /// for non-inference?
...@@ -157,6 +161,27 @@ impl KserveServiceConfigBuilder { ...@@ -157,6 +161,27 @@ impl KserveServiceConfigBuilder {
} }
} }
/// Model configuration resolved for a tensor-based model.
///
/// Produced by `Config::from_runtime_config`; the variant records which of the
/// two supported descriptions the model registered with.
// NOTE(review): the lint is silenced instead of boxing a variant; presumably
// acceptable because values are built per request and matched on right away —
// confirm if this enum ever gets stored long-term.
#[allow(clippy::large_enum_variant)]
enum Config {
    /// Basic description expressed with Dynamo's own `TensorModelConfig`.
    Dynamo(TensorModelConfig),
    /// Triton `ModelConfig` decoded from the serialized protobuf bytes.
    Triton(ModelConfig),
}
impl Config {
    /// Resolves the model configuration carried by `runtime_config`.
    ///
    /// When the tensor model config embeds serialized Triton config bytes,
    /// those bytes are decoded and returned as `Config::Triton`, taking
    /// precedence over the basic fields. Otherwise the tensor model config is
    /// cloned into `Config::Dynamo`.
    ///
    /// # Errors
    ///
    /// Fails with "no model config is provided" when the runtime config has no
    /// tensor model config, or with the protobuf decode error when the Triton
    /// bytes cannot be decoded as a `ModelConfig`.
    fn from_runtime_config(runtime_config: &ModelRuntimeConfig) -> Result<Config, anyhow::Error> {
        let tensor_config = runtime_config
            .tensor_model_config
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("no model config is provided"))?;

        match tensor_config.triton_model_config.as_ref() {
            // Serialized Triton config present: it supersedes the basic fields.
            Some(encoded) => {
                let decoded = ModelConfig::decode(encoded.as_slice())?;
                Ok(Config::Triton(decoded))
            }
            None => Ok(Config::Dynamo(tensor_config.clone())),
        }
    }
}
#[tonic::async_trait] #[tonic::async_trait]
impl GrpcInferenceService for KserveService { impl GrpcInferenceService for KserveService {
async fn model_infer( async fn model_infer(
...@@ -390,13 +415,54 @@ impl GrpcInferenceService for KserveService { ...@@ -390,13 +415,54 @@ impl GrpcInferenceService for KserveService {
.find(|card| request_model_name == &card.display_name) .find(|card| request_model_name == &card.display_name)
{ {
if card.model_type.supports_tensor() { if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| {
{ Status::invalid_argument(format!(
"Model '{}' has type Tensor but: {}",
request_model_name, e
))
})?;
match config {
Config::Triton(model_config) => {
return Ok(Response::new(ModelMetadataResponse { return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(), name: model_config.name,
versions: vec!["1".to_string()],
platform: model_config.platform,
inputs: model_config
.input
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: match inference::DataType::try_from(input.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: input.dims.clone(),
})
.collect(),
outputs: model_config
.output
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: match inference::DataType::try_from(
output.data_type,
) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: output.dims.clone(),
},
)
.collect(),
}));
}
Config::Dynamo(model_config) => {
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name.clone(),
versions: vec!["1".to_string()], versions: vec!["1".to_string()],
platform: "dynamo".to_string(), platform: "dynamo".to_string(),
inputs: tensor_model_config inputs: model_config
.inputs .inputs
.iter() .iter()
.map(|input| inference::model_metadata_response::TensorMetadata { .map(|input| inference::model_metadata_response::TensorMetadata {
...@@ -405,7 +471,7 @@ impl GrpcInferenceService for KserveService { ...@@ -405,7 +471,7 @@ impl GrpcInferenceService for KserveService {
shape: input.shape.clone(), shape: input.shape.clone(),
}) })
.collect(), .collect(),
outputs: tensor_model_config outputs: model_config
.outputs .outputs
.iter() .iter()
.map( .map(
...@@ -418,10 +484,7 @@ impl GrpcInferenceService for KserveService { ...@@ -418,10 +484,7 @@ impl GrpcInferenceService for KserveService {
.collect(), .collect(),
})); }));
} }
Err(Status::invalid_argument(format!( }
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if card.model_type.supports_completions() { } else if card.model_type.supports_completions() {
return Ok(Response::new(ModelMetadataResponse { return Ok(Response::new(ModelMetadataResponse {
name: card.display_name, name: card.display_name,
...@@ -471,8 +534,19 @@ impl GrpcInferenceService for KserveService { ...@@ -471,8 +534,19 @@ impl GrpcInferenceService for KserveService {
.find(|card| request_model_name == &card.display_name) .find(|card| request_model_name == &card.display_name)
{ {
if card.model_type.supports_tensor() { if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref() let config = Config::from_runtime_config(&card.runtime_config).map_err(|e| {
{ Status::invalid_argument(format!(
"Model '{}' has type Tensor but: {}",
request_model_name, e
))
})?;
match config {
Config::Triton(model_config) => {
return Ok(Response::new(ModelConfigResponse {
config: Some(model_config),
}));
}
Config::Dynamo(tensor_model_config) => {
let model_config = ModelConfig { let model_config = ModelConfig {
name: tensor_model_config.name.clone(), name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(), platform: "dynamo".to_string(),
...@@ -503,10 +577,7 @@ impl GrpcInferenceService for KserveService { ...@@ -503,10 +577,7 @@ impl GrpcInferenceService for KserveService {
config: Some(model_config.clone()), config: Some(model_config.clone()),
})); }));
} }
Err(Status::invalid_argument(format!( }
"Model '{}' has type Tensor but no model config is provided",
request_model_name
)))?
} else if card.model_type.supports_completions() { } else if card.model_type.supports_completions() {
let config = ModelConfig { let config = ModelConfig {
name: card.display_name, name: card.display_name,
......
...@@ -124,11 +124,15 @@ pub struct TensorMetadata { ...@@ -124,11 +124,15 @@ pub struct TensorMetadata {
pub parameters: Parameters, pub parameters: Parameters,
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Validate, Debug, Clone, PartialEq, Default)]
pub struct TensorModelConfig { pub struct TensorModelConfig {
pub name: String, pub name: String,
pub inputs: Vec<TensorMetadata>, pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>, pub outputs: Vec<TensorMetadata>,
// Optional Triton model config as serialized protobuf bytes;
// if provided, it supersedes the basic model config defined above.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<Vec<u8>>,
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
......
...@@ -42,6 +42,7 @@ pub mod kserve_test { ...@@ -42,6 +42,7 @@ pub mod kserve_test {
use tonic::{Request, Response, transport::Channel}; use tonic::{Request, Response, transport::Channel};
use dynamo_async_openai::types::Prompt; use dynamo_async_openai::types::Prompt;
use prost::Message;
struct SplitEngine {} struct SplitEngine {}
...@@ -361,6 +362,7 @@ pub mod kserve_test { ...@@ -361,6 +362,7 @@ pub mod kserve_test {
ModelInfo = 8994, ModelInfo = 8994,
TensorModel = 8995, TensorModel = 8995,
TensorModelTypes = 8996, TensorModelTypes = 8996,
TritonModelConfig = 8997,
} }
#[rstest] #[rstest]
...@@ -1173,6 +1175,7 @@ pub mod kserve_test { ...@@ -1173,6 +1175,7 @@ pub mod kserve_test {
shape: vec![-1], shape: vec![-1],
parameters: Default::default(), parameters: Default::default(),
}], }],
triton_model_config: None,
}), }),
..Default::default() ..Default::default()
}; };
...@@ -1206,6 +1209,193 @@ pub mod kserve_test { ...@@ -1206,6 +1209,193 @@ pub mod kserve_test {
); );
} }
// Verifies how the ModelConfig endpoint treats `triton_model_config`:
//   1. a card carrying only the serialized Triton config gets that config back;
//   2. a card carrying both the basic fields and the Triton config still gets
//      the Triton config back;
//   3. a card carrying undecodable bytes is rejected with InvalidArgument.
#[rstest]
#[tokio::test]
async fn test_triton_model_config(
    #[with(TestPort::TritonModelConfig as u16)] service_with_engines: (
        KserveService,
        Arc<SplitEngine>,
        Arc<AlwaysFailEngine>,
        Arc<LongRunningEngine>,
    ),
) {
    // start server
    let _running = RunningService::spawn(service_with_engines.0.clone());

    let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await;

    let model_name = "tensor";
    let expected_model_config = inference::ModelConfig {
        name: model_name.to_string(),
        platform: "custom".to_string(),
        backend: "custom".to_string(),
        input: vec![
            inference::ModelInput {
                name: "input".to_string(),
                data_type: DataType::TypeInt32 as i32,
                dims: vec![1],
                optional: false,
                ..Default::default()
            },
            inference::ModelInput {
                name: "optional_input".to_string(),
                data_type: DataType::TypeInt32 as i32,
                dims: vec![1],
                optional: true,
                ..Default::default()
            },
        ],
        output: vec![inference::ModelOutput {
            name: "output".to_string(),
            data_type: DataType::TypeBool as i32,
            dims: vec![-1],
            ..Default::default()
        }],
        model_transaction_policy: Some(inference::ModelTransactionPolicy { decoupled: true }),
        ..Default::default()
    };

    // Serialized form of the expected config, as a worker would register it.
    let mut buf = vec![];
    expected_model_config.encode(&mut buf).unwrap();

    // Register a tensor model that only provides the Triton model config
    let mut card = ModelDeploymentCard::with_name_only(model_name);
    card.model_type = ModelType::TensorBased;
    card.model_input = ModelInput::Tensor;
    card.runtime_config = ModelRuntimeConfig {
        tensor_model_config: Some(tensor::TensorModelConfig {
            triton_model_config: Some(buf.clone()),
            ..Default::default()
        }),
        ..Default::default()
    };
    let tensor = Arc::new(TensorEngine {});
    service_with_engines
        .0
        .model_manager()
        .add_tensor_model(model_name, card.mdcsum(), tensor.clone())
        .unwrap();
    let _ = service_with_engines
        .0
        .model_manager()
        .save_model_card("key", card);

    // success config
    let request = tonic::Request::new(ModelConfigRequest {
        name: model_name.into(),
        version: "".into(),
    });

    let response = client
        .model_config(request)
        .await
        .unwrap()
        .into_inner()
        .config;
    let Some(config) = response else {
        panic!("Expected Some(config), got None");
    };
    assert_eq!(
        config, expected_model_config,
        "Expected same model config to be returned",
    );

    // Pass config with both TensorModelConfig and triton_model_config,
    // check if the Triton model config is used.
    let _ = service_with_engines
        .0
        .model_manager()
        .remove_model_card("key");
    let mut card = ModelDeploymentCard::with_name_only(model_name);
    card.model_type = ModelType::TensorBased;
    card.model_input = ModelInput::Tensor;
    card.runtime_config = ModelRuntimeConfig {
        tensor_model_config: Some(tensor::TensorModelConfig {
            name: model_name.to_string(),
            inputs: vec![tensor::TensorMetadata {
                name: "input".to_string(),
                data_type: tensor::DataType::Int32,
                shape: vec![1],
                parameters: Default::default(),
            }],
            outputs: vec![tensor::TensorMetadata {
                name: "output".to_string(),
                data_type: tensor::DataType::Bool,
                shape: vec![-1],
                parameters: Default::default(),
            }],
            triton_model_config: Some(buf.clone()),
        }),
        ..Default::default()
    };
    let _ = service_with_engines
        .0
        .model_manager()
        .save_model_card("key", card);

    let request = tonic::Request::new(ModelConfigRequest {
        name: model_name.into(),
        version: "".into(),
    });

    let response = client
        .model_config(request)
        .await
        .unwrap()
        .into_inner()
        .config;
    let Some(config) = response else {
        panic!("Expected Some(config), got None");
    };
    assert_eq!(
        config, expected_model_config,
        "Expected same model config to be returned",
    );

    // Test invalid triton model config
    let _ = service_with_engines
        .0
        .model_manager()
        .remove_model_card("key");
    let mut card = ModelDeploymentCard::with_name_only(model_name);
    card.model_type = ModelType::TensorBased;
    card.model_input = ModelInput::Tensor;
    card.runtime_config = ModelRuntimeConfig {
        tensor_model_config: Some(tensor::TensorModelConfig {
            triton_model_config: Some(vec![1, 2, 3, 4, 5]),
            ..Default::default()
        }),
        ..Default::default()
    };
    let _ = service_with_engines
        .0
        .model_manager()
        .save_model_card("key", card);

    // failure: bytes that do not decode must surface as InvalidArgument
    let request = tonic::Request::new(ModelConfigRequest {
        name: model_name.into(),
        version: "".into(),
    });

    let response = client.model_config(request).await;
    assert!(response.is_err());
    let err = response.unwrap_err();
    assert_eq!(
        err.code(),
        tonic::Code::InvalidArgument,
        "Expected InvalidArgument error, get {}",
        err
    );
    assert!(
        err.message().contains("failed to decode Protobuf message"),
        "Expected error message to contain 'failed to decode Protobuf message', got: {}",
        err.message()
    );
}
#[rstest] #[rstest]
#[tokio::test] #[tokio::test]
async fn test_tensor_infer( async fn test_tensor_infer(
...@@ -1255,9 +1445,8 @@ pub mod kserve_test { ...@@ -1255,9 +1445,8 @@ pub mod kserve_test {
err err
); );
assert!( assert!(
err.message() err.message().contains("no model config is provided"),
.contains("has type Tensor but no model config is provided"), "Expected error message to contain 'no model config is provided', got: {}",
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message() err.message()
); );
...@@ -1276,9 +1465,8 @@ pub mod kserve_test { ...@@ -1276,9 +1465,8 @@ pub mod kserve_test {
err err
); );
assert!( assert!(
err.message() err.message().contains("no model config is provided"),
.contains("has type Tensor but no model config is provided"), "Expected error message to contain 'no model config is provided', got: {}",
"Expected error message to contain 'has type Tensor but no model config is provided', got: {}",
err.message() err.message()
); );
...@@ -1305,6 +1493,7 @@ pub mod kserve_test { ...@@ -1305,6 +1493,7 @@ pub mod kserve_test {
shape: vec![-1], shape: vec![-1],
parameters: Default::default(), parameters: Default::default(),
}], }],
triton_model_config: None,
}), }),
..Default::default() ..Default::default()
}; };
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker. # Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as tensor based echo worker.
# The test is expected to run in an environment that has tritonclient installed,
# which contains the generated module equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
import uvloop import uvloop
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
...@@ -17,17 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime): ...@@ -17,17 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime):
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
triton_model_config = mc.ModelConfig()
triton_model_config.name = "echo"
triton_model_config.platform = "custom"
input_tensor = triton_model_config.input.add()
input_tensor.name = "input"
input_tensor.data_type = mc.TYPE_STRING
input_tensor.dims.extend([-1])
optional_input_tensor = triton_model_config.input.add()
optional_input_tensor.name = "optional_input"
optional_input_tensor.data_type = mc.TYPE_INT32
optional_input_tensor.dims.extend([-1])
optional_input_tensor.optional = True
output_tensor = triton_model_config.output.add()
output_tensor.name = "dummy_output"
output_tensor.data_type = mc.TYPE_STRING
output_tensor.dims.extend([-1])
triton_model_config.model_transaction_policy.decoupled = True
model_config = { model_config = {
"name": "echo", "name": "",
"inputs": [ "inputs": [],
{"name": "dummy_input", "data_type": "Bytes", "shape": [-1]}, "outputs": [],
], "triton_model_config": triton_model_config.SerializeToString(),
"outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}],
} }
runtime_config = ModelRuntimeConfig() runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config) runtime_config.set_tensor_model_config(model_config)
assert model_config == runtime_config.get_tensor_model_config() # Internally the bytes string will be converted to List of int
retrieved_model_config = runtime_config.get_tensor_model_config()
retrieved_model_config["triton_model_config"] = bytes(
retrieved_model_config["triton_model_config"]
)
assert model_config == retrieved_model_config
# [gluo FIXME] register_llm will attempt to load a LLM model, # [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide # which is not well-defined for Tensor yet. Currently provide
...@@ -46,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime): ...@@ -46,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime):
async def generate(request, context): async def generate(request, context):
"""Echo tensors and parameters back to the client.""" """Echo tensors and parameters back to the client."""
# [NOTE] gluo: currently there is no frontend side
# validation between model config and actual request,
# so any request will reach here and be echoed back.
print(f"Echoing request: {request}") print(f"Echoing request: {request}")
params = {} params = {}
......
...@@ -120,3 +120,4 @@ def start_services(request, runtime_services): ...@@ -120,3 +120,4 @@ def start_services(request, runtime_services):
@pytest.mark.model(TEST_MODEL) @pytest.mark.model(TEST_MODEL)
def test_echo() -> None: def test_echo() -> None:
triton_echo_client.run_infer() triton_echo_client.run_infer()
triton_echo_client.get_config()
...@@ -43,3 +43,17 @@ def run_infer(): ...@@ -43,3 +43,17 @@ def run_infer():
assert np.array_equal(input0_data, output0_data) assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data) assert np.array_equal(input1_data, output1_data)
def get_config():
    """Fetch the "echo" model config over gRPC and check that it is decoupled.

    Connects to the server at localhost:8000, requests the model config and
    asserts that `model_transaction_policy.decoupled` is set.

    Raises:
        SystemExit: with a non-zero status if the client cannot be created.
        AssertionError: if the returned config is not marked as decoupled.
    """
    server_url = "localhost:8000"
    try:
        triton_client = grpcclient.InferenceServerClient(url=server_url)
    except Exception as e:
        print("channel creation failed: " + str(e))
        # Exit with a failure status so a broken setup is not reported as success.
        sys.exit(1)

    model_name = "echo"
    try:
        response = triton_client.get_model_config(model_name=model_name)
    finally:
        # Release the gRPC channel whether or not the request succeeded.
        triton_client.close()

    # Check one of the fields that can only be set by providing a Triton model config
    assert response.config.model_transaction_policy.decoupled
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment