Commit 31d27ab2 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Endpoint defaults for namespace/component/other (#277)

This means we don't need to explain the parts to the users until they are ready. We use what they provide and default the rest.

Allows all of this and more:
- `tio out=tdr://test`
- `tio out=tdr://llama_8b_pool`
- `tio in=tdr://corp_ai_research_group/model_next-20250226`
- `tio out=tdr://AIRE.NIM.migrate.mistralrs.1802`

Python, API, etc all untouched.
parent 76439997
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
// limitations under the License. // limitations under the License.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::str::FromStr; use std::str::FromStr;
use crate::pipeline::PipelineError; use crate::pipeline::PipelineError;
...@@ -23,6 +22,13 @@ pub mod annotated; ...@@ -23,6 +22,13 @@ pub mod annotated;
pub type LeaseId = i64; pub type LeaseId = i64;
/// Default namespace if user does not provide one
const DEFAULT_NAMESPACE: &str = "NS";
const DEFAULT_COMPONENT: &str = "C";
const DEFAULT_ENDPOINT: &str = "E";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Component { pub struct Component {
pub name: String, pub name: String,
...@@ -31,7 +37,7 @@ pub struct Component { ...@@ -31,7 +37,7 @@ pub struct Component {
/// Represents an endpoint with a namespace, component, and name. /// Represents an endpoint with a namespace, component, and name.
/// ///
/// An `Endpoint` is defined by a three-part string separated by `/`: /// An `Endpoint` is defined by a three-part string separated by `/` or a '.':
/// - **namespace** /// - **namespace**
/// - **component** /// - **component**
/// - **name** /// - **name**
...@@ -39,49 +45,95 @@ pub struct Component { ...@@ -39,49 +45,95 @@ pub struct Component {
/// Example format: `"namespace/component/endpoint"` /// Example format: `"namespace/component/endpoint"`
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint { pub struct Endpoint {
/// Name of the endpoint. pub namespace: String,
pub component: String,
pub name: String, pub name: String,
}
/// Component of the endpoint. impl PartialEq<Vec<&str>> for Endpoint {
pub component: String, fn eq(&self, other: &Vec<&str>) -> bool {
if other.len() != 3 {
return false;
}
/// Namespace of the component. self.namespace == other[0] && self.component == other[1] && self.name == other[2]
pub namespace: String, }
}
impl PartialEq<Endpoint> for Vec<&str> {
fn eq(&self, other: &Endpoint) -> bool {
other == self
}
} }
impl TryFrom<&str> for Endpoint { impl Default for Endpoint {
type Error = PipelineError; fn default() -> Self {
Endpoint {
namespace: DEFAULT_NAMESPACE.to_string(),
component: DEFAULT_COMPONENT.to_string(),
name: DEFAULT_ENDPOINT.to_string(),
}
}
}
/// Attempts to create an `Endpoint` from a string. impl From<&str> for Endpoint {
/// Creates an `Endpoint` from a string.
/// ///
/// # Arguments /// # Arguments
/// - `path`: A string in the format `"namespace/component/endpoint"`. /// - `path`: A string in the format `"namespace/component/endpoint"`.
/// ///
/// # Errors /// The first two parts become the first two elements of the vector.
/// Returns a `PipelineError::InvalidFormat` if the input string does not /// The third and subsequent parts are joined with '_' and become the third element.
/// have exactly three parts separated by `/`. /// Default values are used for missing parts.
///
/// # Examples:
/// - "component" -> ["DEFAULT_NS", "component", "DEFAULT_E"]
/// - "namespace.component" -> ["namespace", "component", "DEFAULT_E"]
/// - "namespace.component.endpoint" -> ["namespace", "component", "endpoint"]
/// - "namespace/component" -> ["namespace", "component", "DEFAULT_E"]
/// - "namespace.component.endpoint.other.parts" -> ["namespace", "component", "endpoint_other_parts"]
/// ///
/// # Examples /// # Examples
/// ```ignore /// ```ignore
/// use std::convert::TryFrom;
/// use triton_distributed::protocols::Endpoint; /// use triton_distributed::protocols::Endpoint;
/// ///
/// let endpoint = Endpoint::try_from("namespace/component/endpoint").unwrap(); /// let endpoint = Endpoint::from("namespace/component/endpoint");
/// assert_eq!(endpoint.namespace, "namespace"); /// assert_eq!(endpoint.namespace, "namespace");
/// assert_eq!(endpoint.component, "component"); /// assert_eq!(endpoint.component, "component");
/// assert_eq!(endpoint.name, "endpoint"); /// assert_eq!(endpoint.name, "endpoint");
/// ``` /// ```
fn try_from(path: &str) -> Result<Self, Self::Error> { fn from(input: &str) -> Self {
let elements: Vec<&str> = path.split('/').collect(); let mut result = Endpoint::default();
if elements.len() != 3 {
return Err(PipelineError::InvalidEndpointFormat); // Split the input string on either '.' or '/'
let elements: Vec<&str> = input
.trim_matches([' ', '/', '.'])
.split(['.', '/'])
.filter(|x| !x.is_empty())
.collect();
match elements.len() {
0 => {}
1 => {
result.component = elements[0].to_string();
}
2 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
}
3 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
result.name = elements[2].to_string();
}
x if x > 3 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
result.name = elements[2..].join("_");
}
_ => unreachable!(),
} }
result
Ok(Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
})
} }
} }
...@@ -90,10 +142,10 @@ impl FromStr for Endpoint { ...@@ -90,10 +142,10 @@ impl FromStr for Endpoint {
/// Parses an `Endpoint` from a string using the standard Rust `.parse::<T>()` pattern. /// Parses an `Endpoint` from a string using the standard Rust `.parse::<T>()` pattern.
/// ///
/// This is implemented in terms of [`TryFrom<&str>`]. /// This is implemented in terms of [`From<&str>`].
/// ///
/// # Errors /// # Errors
/// Returns an `PipelineError::InvalidFormat` if the input does not match `"namespace/component/endpoint"`. /// Does not fail
/// ///
/// # Examples /// # Examples
/// ```ignore /// ```ignore
...@@ -106,7 +158,7 @@ impl FromStr for Endpoint { ...@@ -106,7 +158,7 @@ impl FromStr for Endpoint {
/// assert_eq!(endpoint.name, "endpoint"); /// assert_eq!(endpoint.name, "endpoint");
/// ``` /// ```
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
Endpoint::try_from(s) Ok(Endpoint::from(s))
} }
} }
...@@ -136,30 +188,6 @@ mod tests { ...@@ -136,30 +188,6 @@ mod tests {
use std::convert::TryFrom; use std::convert::TryFrom;
use std::str::FromStr; use std::str::FromStr;
#[test]
fn test_component_creation() {
let component = Component {
name: "test_name".to_string(),
namespace: "test_namespace".to_string(),
};
assert_eq!(component.name, "test_name");
assert_eq!(component.namespace, "test_namespace");
}
#[test]
fn test_endpoint_creation() {
let endpoint = Endpoint {
name: "test_endpoint".to_string(),
component: "test_component".to_string(),
namespace: "test_namespace".to_string(),
};
assert_eq!(endpoint.name, "test_endpoint");
assert_eq!(endpoint.component, "test_component");
assert_eq!(endpoint.namespace, "test_namespace");
}
#[test] #[test]
fn test_router_type_default() { fn test_router_type_default() {
let default_router = RouterType::default(); let default_router = RouterType::default();
...@@ -188,28 +216,9 @@ mod tests { ...@@ -188,28 +216,9 @@ mod tests {
} }
#[test] #[test]
fn test_model_metadata_creation() { fn test_valid_endpoint_from() {
let component = Component {
name: "test_component".to_string(),
namespace: "test_namespace".to_string(),
};
let metadata = ModelMetaData {
name: "test_model".to_string(),
component,
router_type: RouterType::PushRoundRobin,
};
assert_eq!(metadata.name, "test_model");
assert_eq!(metadata.component.name, "test_component");
assert_eq!(metadata.component.namespace, "test_namespace");
assert_eq!(metadata.router_type, RouterType::PushRoundRobin);
}
#[test]
fn test_valid_endpoint_try_from() {
let input = "namespace1/component1/endpoint1"; let input = "namespace1/component1/endpoint1";
let endpoint = Endpoint::try_from(input).expect("Valid endpoint should parse successfully"); let endpoint = Endpoint::from(input);
assert_eq!(endpoint.namespace, "namespace1"); assert_eq!(endpoint.namespace, "namespace1");
assert_eq!(endpoint.component, "component1"); assert_eq!(endpoint.component, "component1");
...@@ -219,7 +228,7 @@ mod tests { ...@@ -219,7 +228,7 @@ mod tests {
#[test] #[test]
fn test_valid_endpoint_from_str() { fn test_valid_endpoint_from_str() {
let input = "namespace2/component2/endpoint2"; let input = "namespace2/component2/endpoint2";
let endpoint = Endpoint::from_str(input).expect("Valid endpoint should parse successfully"); let endpoint = Endpoint::from_str(input).unwrap();
assert_eq!(endpoint.namespace, "namespace2"); assert_eq!(endpoint.namespace, "namespace2");
assert_eq!(endpoint.component, "component2"); assert_eq!(endpoint.component, "component2");
...@@ -229,9 +238,7 @@ mod tests { ...@@ -229,9 +238,7 @@ mod tests {
#[test] #[test]
fn test_valid_endpoint_parse() { fn test_valid_endpoint_parse() {
let input = "namespace3/component3/endpoint3"; let input = "namespace3/component3/endpoint3";
let endpoint: Endpoint = input let endpoint: Endpoint = input.parse().unwrap();
.parse()
.expect("Valid endpoint should parse successfully");
assert_eq!(endpoint.namespace, "namespace3"); assert_eq!(endpoint.namespace, "namespace3");
assert_eq!(endpoint.component, "component3"); assert_eq!(endpoint.component, "component3");
...@@ -239,76 +246,55 @@ mod tests { ...@@ -239,76 +246,55 @@ mod tests {
} }
#[test] #[test]
fn test_invalid_endpoint_try_from() { fn test_endpoint_from() {
let input = "invalid_endpoint_format"; let result = Endpoint::from("component");
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!( assert_eq!(
result.unwrap_err().to_string(), result,
"An endpoint URL must have the format: namespace/component/endpoint" vec![DEFAULT_NAMESPACE, "component", DEFAULT_ENDPOINT]
); );
} }
#[test] #[test]
fn test_invalid_endpoint_from_str() { fn test_namespace_component_endpoint() {
let input = "onlyhas/two"; let result = Endpoint::from("namespace.component.endpoint");
assert_eq!(result, vec!["namespace", "component", "endpoint"]);
let result = Endpoint::from_str(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
} }
#[test] #[test]
fn test_invalid_endpoint_parse() { fn test_forward_slash_separator() {
let input = "too/many/segments/in/url"; let result = Endpoint::from("namespace/component");
assert_eq!(result, vec!["namespace", "component", DEFAULT_ENDPOINT]);
let result: Result<Endpoint, _> = input.parse();
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
} }
#[test] #[test]
fn test_empty_endpoint_string() { fn test_multiple_parts() {
let input = ""; let result = Endpoint::from("namespace.component.endpoint.other.parts");
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an empty string");
assert_eq!( assert_eq!(
result.unwrap_err().to_string(), result,
"An endpoint URL must have the format: namespace/component/endpoint" vec!["namespace", "component", "endpoint_other_parts"]
); );
} }
#[test] #[test]
fn test_whitespace_endpoint_string() { fn test_mixed_separators() {
let input = " "; // Do it the .into way for variety and documentation
let result: Endpoint = "namespace/component.endpoint".into();
let result = Endpoint::try_from(input); assert_eq!(result, vec!["namespace", "component", "endpoint"]);
assert!(
result.is_err(),
"Parsing should fail for a whitespace string"
);
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
} }
#[test] #[test]
fn test_leading_trailing_slashes() { fn test_empty_string() {
let input = "/namespace/component/endpoint/"; let result = Endpoint::from("");
assert_eq!(
result,
vec![DEFAULT_NAMESPACE, DEFAULT_COMPONENT, DEFAULT_ENDPOINT]
);
let result = Endpoint::try_from(input); // White space is equivalent to an empty string
assert!( let result = Endpoint::from(" ");
result.is_err(), assert_eq!(
"Parsing should fail for leading/trailing slashes" result,
vec![DEFAULT_NAMESPACE, DEFAULT_COMPONENT, DEFAULT_ENDPOINT]
); );
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment