Commit 31d27ab2 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Endpoint defaults for namespace/component/other (#277)

This means we don't need to explain the parts to the users until they are ready. We use what they provide and default the rest.

Allows all of this and more:
- `tio out=tdr://test`
- `tio out=tdr://llama_8b_pool`
- `tio in=tdr://corp_ai_research_group/model_next-20250226`
- `tio out=tdr://AIRE.NIM.migrate.mistralrs.1802`

Python, API, etc all untouched.
parent 76439997
......@@ -14,7 +14,6 @@
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::str::FromStr;
use crate::pipeline::PipelineError;
......@@ -23,6 +22,13 @@ pub mod annotated;
pub type LeaseId = i64;
/// Default namespace if user does not provide one
const DEFAULT_NAMESPACE: &str = "NS";
const DEFAULT_COMPONENT: &str = "C";
const DEFAULT_ENDPOINT: &str = "E";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Component {
pub name: String,
......@@ -31,7 +37,7 @@ pub struct Component {
/// Represents an endpoint with a namespace, component, and name.
///
/// An `Endpoint` is defined by a three-part string separated by `/`:
/// An `Endpoint` is defined by a three-part string separated by `/` or a '.':
/// - **namespace**
/// - **component**
/// - **name**
......@@ -39,49 +45,95 @@ pub struct Component {
/// Example format: `"namespace/component/endpoint"`
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint {
/// Name of the endpoint.
pub namespace: String,
pub component: String,
pub name: String,
}
/// Component of the endpoint.
pub component: String,
impl PartialEq<Vec<&str>> for Endpoint {
fn eq(&self, other: &Vec<&str>) -> bool {
if other.len() != 3 {
return false;
}
/// Namespace of the component.
pub namespace: String,
self.namespace == other[0] && self.component == other[1] && self.name == other[2]
}
}
impl PartialEq<Endpoint> for Vec<&str> {
fn eq(&self, other: &Endpoint) -> bool {
other == self
}
}
impl TryFrom<&str> for Endpoint {
type Error = PipelineError;
impl Default for Endpoint {
fn default() -> Self {
Endpoint {
namespace: DEFAULT_NAMESPACE.to_string(),
component: DEFAULT_COMPONENT.to_string(),
name: DEFAULT_ENDPOINT.to_string(),
}
}
}
/// Attempts to create an `Endpoint` from a string.
impl From<&str> for Endpoint {
/// Creates an `Endpoint` from a string.
///
/// # Arguments
/// - `path`: A string in the format `"namespace/component/endpoint"`.
///
/// # Errors
/// Returns a `PipelineError::InvalidFormat` if the input string does not
/// have exactly three parts separated by `/`.
/// The first two parts become the first two elements of the vector.
/// The third and subsequent parts are joined with '_' and become the third element.
/// Default values are used for missing parts.
///
/// # Examples:
/// - "component" -> ["DEFAULT_NS", "component", "DEFAULT_E"]
/// - "namespace.component" -> ["namespace", "component", "DEFAULT_E"]
/// - "namespace.component.endpoint" -> ["namespace", "component", "endpoint"]
/// - "namespace/component" -> ["namespace", "component", "DEFAULT_E"]
/// - "namespace.component.endpoint.other.parts" -> ["namespace", "component", "endpoint_other_parts"]
///
/// # Examples
/// ```ignore
/// use std::convert::TryFrom;
/// use triton_distributed::protocols::Endpoint;
///
/// let endpoint = Endpoint::try_from("namespace/component/endpoint").unwrap();
/// let endpoint = Endpoint::from("namespace/component/endpoint");
/// assert_eq!(endpoint.namespace, "namespace");
/// assert_eq!(endpoint.component, "component");
/// assert_eq!(endpoint.name, "endpoint");
/// ```
fn try_from(path: &str) -> Result<Self, Self::Error> {
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
return Err(PipelineError::InvalidEndpointFormat);
fn from(input: &str) -> Self {
let mut result = Endpoint::default();
// Split the input string on either '.' or '/'
let elements: Vec<&str> = input
.trim_matches([' ', '/', '.'])
.split(['.', '/'])
.filter(|x| !x.is_empty())
.collect();
match elements.len() {
0 => {}
1 => {
result.component = elements[0].to_string();
}
Ok(Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
})
2 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
}
3 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
result.name = elements[2].to_string();
}
x if x > 3 => {
result.namespace = elements[0].to_string();
result.component = elements[1].to_string();
result.name = elements[2..].join("_");
}
_ => unreachable!(),
}
result
}
}
......@@ -90,10 +142,10 @@ impl FromStr for Endpoint {
/// Parses an `Endpoint` from a string using the standard Rust `.parse::<T>()` pattern.
///
/// This is implemented in terms of [`TryFrom<&str>`].
/// This is implemented in terms of [`From<&str>`].
///
/// # Errors
/// Returns an `PipelineError::InvalidFormat` if the input does not match `"namespace/component/endpoint"`.
/// Does not fail
///
/// # Examples
/// ```ignore
......@@ -106,7 +158,7 @@ impl FromStr for Endpoint {
/// assert_eq!(endpoint.name, "endpoint");
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
Endpoint::try_from(s)
Ok(Endpoint::from(s))
}
}
......@@ -136,30 +188,6 @@ mod tests {
use std::convert::TryFrom;
use std::str::FromStr;
#[test]
fn test_component_creation() {
let component = Component {
name: "test_name".to_string(),
namespace: "test_namespace".to_string(),
};
assert_eq!(component.name, "test_name");
assert_eq!(component.namespace, "test_namespace");
}
#[test]
fn test_endpoint_creation() {
let endpoint = Endpoint {
name: "test_endpoint".to_string(),
component: "test_component".to_string(),
namespace: "test_namespace".to_string(),
};
assert_eq!(endpoint.name, "test_endpoint");
assert_eq!(endpoint.component, "test_component");
assert_eq!(endpoint.namespace, "test_namespace");
}
#[test]
fn test_router_type_default() {
let default_router = RouterType::default();
......@@ -188,28 +216,9 @@ mod tests {
}
#[test]
fn test_model_metadata_creation() {
let component = Component {
name: "test_component".to_string(),
namespace: "test_namespace".to_string(),
};
let metadata = ModelMetaData {
name: "test_model".to_string(),
component,
router_type: RouterType::PushRoundRobin,
};
assert_eq!(metadata.name, "test_model");
assert_eq!(metadata.component.name, "test_component");
assert_eq!(metadata.component.namespace, "test_namespace");
assert_eq!(metadata.router_type, RouterType::PushRoundRobin);
}
#[test]
fn test_valid_endpoint_try_from() {
fn test_valid_endpoint_from() {
let input = "namespace1/component1/endpoint1";
let endpoint = Endpoint::try_from(input).expect("Valid endpoint should parse successfully");
let endpoint = Endpoint::from(input);
assert_eq!(endpoint.namespace, "namespace1");
assert_eq!(endpoint.component, "component1");
......@@ -219,7 +228,7 @@ mod tests {
#[test]
fn test_valid_endpoint_from_str() {
let input = "namespace2/component2/endpoint2";
let endpoint = Endpoint::from_str(input).expect("Valid endpoint should parse successfully");
let endpoint = Endpoint::from_str(input).unwrap();
assert_eq!(endpoint.namespace, "namespace2");
assert_eq!(endpoint.component, "component2");
......@@ -229,9 +238,7 @@ mod tests {
#[test]
fn test_valid_endpoint_parse() {
let input = "namespace3/component3/endpoint3";
let endpoint: Endpoint = input
.parse()
.expect("Valid endpoint should parse successfully");
let endpoint: Endpoint = input.parse().unwrap();
assert_eq!(endpoint.namespace, "namespace3");
assert_eq!(endpoint.component, "component3");
......@@ -239,76 +246,55 @@ mod tests {
}
#[test]
fn test_invalid_endpoint_try_from() {
let input = "invalid_endpoint_format";
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
fn test_endpoint_from() {
let result = Endpoint::from("component");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
result,
vec![DEFAULT_NAMESPACE, "component", DEFAULT_ENDPOINT]
);
}
#[test]
fn test_invalid_endpoint_from_str() {
let input = "onlyhas/two";
let result = Endpoint::from_str(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
fn test_namespace_component_endpoint() {
let result = Endpoint::from("namespace.component.endpoint");
assert_eq!(result, vec!["namespace", "component", "endpoint"]);
}
#[test]
fn test_invalid_endpoint_parse() {
let input = "too/many/segments/in/url";
let result: Result<Endpoint, _> = input.parse();
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
fn test_forward_slash_separator() {
let result = Endpoint::from("namespace/component");
assert_eq!(result, vec!["namespace", "component", DEFAULT_ENDPOINT]);
}
#[test]
fn test_empty_endpoint_string() {
let input = "";
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an empty string");
fn test_multiple_parts() {
let result = Endpoint::from("namespace.component.endpoint.other.parts");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
result,
vec!["namespace", "component", "endpoint_other_parts"]
);
}
#[test]
fn test_whitespace_endpoint_string() {
let input = " ";
let result = Endpoint::try_from(input);
assert!(
result.is_err(),
"Parsing should fail for a whitespace string"
);
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
fn test_mixed_separators() {
// Do it the .into way for variety and documentation
let result: Endpoint = "namespace/component.endpoint".into();
assert_eq!(result, vec!["namespace", "component", "endpoint"]);
}
#[test]
fn test_leading_trailing_slashes() {
let input = "/namespace/component/endpoint/";
fn test_empty_string() {
let result = Endpoint::from("");
assert_eq!(
result,
vec![DEFAULT_NAMESPACE, DEFAULT_COMPONENT, DEFAULT_ENDPOINT]
);
let result = Endpoint::try_from(input);
assert!(
result.is_err(),
"Parsing should fail for leading/trailing slashes"
// White space is equivalent to an empty string
let result = Endpoint::from(" ");
assert_eq!(
result,
vec![DEFAULT_NAMESPACE, DEFAULT_COMPONENT, DEFAULT_ENDPOINT]
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment