Commit e0e9f4a2 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: adds `TryFrom<&str>` and `FromStr` for `Endpoint` (#263)

parent 72064d84
.idea
.vs/
.vscode/
.helix
[Bb]inlog/
[Bb][Uu][Ii][Ll][Dd]/
[Cc][Mm][Aa][Kk][Ee]/
......
......@@ -27,7 +27,7 @@ use triton_distributed_runtime::pipeline::{
};
use triton_distributed_runtime::{protocols::Endpoint, DistributedRuntime, Runtime};
use crate::{EngineConfig, ENDPOINT_SCHEME};
use crate::EngineConfig;
pub async fn run(
runtime: Runtime,
......@@ -38,18 +38,8 @@ pub async fn run(
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let cancel_token = runtime.primary_token().clone();
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
anyhow::bail!(
"An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint"
);
}
let endpoint: Endpoint = path.parse()?;
let endpoint = Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
};
let etcd_client = distributed.etcd_client();
let (ingress, service_name) = match engine_config {
......@@ -89,7 +79,7 @@ pub async fn run(
let model_registration = ModelEntry {
name: service_name.to_string(),
endpoint,
endpoint: endpoint.clone(),
};
etcd_client
.kv_create(
......@@ -100,12 +90,12 @@ pub async fn run(
.await?;
let rt_fut = distributed
.namespace(elements[0])?
.component(elements[1])?
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.service_builder()
.create()
.await?
.endpoint(elements[2])
.endpoint(endpoint.name)
.endpoint_builder()
.handler(ingress)
.start();
......
......@@ -26,7 +26,7 @@ use triton_distributed_llm::{
Annotated,
},
};
use triton_distributed_runtime::{component::Client, DistributedRuntime};
use triton_distributed_runtime::{component::Client, protocols::Endpoint, DistributedRuntime};
mod input;
mod opt;
......@@ -136,17 +136,15 @@ pub async fn run(
}
}
Output::Endpoint(path) => {
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
anyhow::bail!("An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint");
}
let endpoint: Endpoint = path.parse()?;
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let client = distributed_runtime
.namespace(elements[0])?
.component(elements[1])?
.endpoint(elements[2])
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.endpoint(endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
......
......@@ -87,6 +87,9 @@ pub enum PipelineError {
#[error("Generate Error: {0}")]
GenerateError(Error),
#[error("An endpoint URL must have the format: namespace/component/endpoint")]
InvalidEndpointFormat,
#[error("NATS Request Error: {0}")]
NatsRequestError(#[from] NatsError<async_nats::jetstream::context::RequestErrorKind>),
......
......@@ -14,6 +14,10 @@
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::str::FromStr;
use crate::pipeline::PipelineError;
pub mod annotated;
......@@ -25,6 +29,14 @@ pub struct Component {
pub namespace: String,
}
/// Represents an endpoint with a namespace, component, and name.
///
/// An `Endpoint` is defined by a three-part string separated by `/`:
/// - **namespace**
/// - **component**
/// - **name**
///
/// Example format: `"namespace/component/endpoint"`
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint {
/// Name of the endpoint.
......@@ -37,6 +49,67 @@ pub struct Endpoint {
pub namespace: String,
}
impl TryFrom<&str> for Endpoint {
type Error = PipelineError;
/// Attempts to create an `Endpoint` from a string.
///
/// # Arguments
/// - `path`: A string in the format `"namespace/component/endpoint"`.
///
/// # Errors
/// Returns a `PipelineError::InvalidFormat` if the input string does not
/// have exactly three parts separated by `/`.
///
/// # Examples
/// ```ignore
/// use std::convert::TryFrom;
/// use triton_distributed::protocols::Endpoint;
///
/// let endpoint = Endpoint::try_from("namespace/component/endpoint").unwrap();
/// assert_eq!(endpoint.namespace, "namespace");
/// assert_eq!(endpoint.component, "component");
/// assert_eq!(endpoint.name, "endpoint");
/// ```
fn try_from(path: &str) -> Result<Self, Self::Error> {
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
return Err(PipelineError::InvalidEndpointFormat);
}
Ok(Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
})
}
}
impl FromStr for Endpoint {
type Err = PipelineError;
/// Parses an `Endpoint` from a string using the standard Rust `.parse::<T>()` pattern.
///
/// This is implemented in terms of [`TryFrom<&str>`].
///
/// # Errors
/// Returns an `PipelineError::InvalidFormat` if the input does not match `"namespace/component/endpoint"`.
///
/// # Examples
/// ```ignore
/// use std::str::FromStr;
/// use triton_distributed::protocols::Endpoint;
///
/// let endpoint: Endpoint = "namespace/component/endpoint".parse().unwrap();
/// assert_eq!(endpoint.namespace, "namespace");
/// assert_eq!(endpoint.component, "component");
/// assert_eq!(endpoint.name, "endpoint");
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
Endpoint::try_from(s)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RouterType {
......@@ -60,6 +133,8 @@ pub struct ModelMetaData {
#[cfg(test)]
mod tests {
use super::*;
use std::convert::TryFrom;
use std::str::FromStr;
#[test]
fn test_component_creation() {
......@@ -130,4 +205,110 @@ mod tests {
assert_eq!(metadata.component.namespace, "test_namespace");
assert_eq!(metadata.router_type, RouterType::PushRoundRobin);
}
#[test]
fn test_valid_endpoint_try_from() {
let input = "namespace1/component1/endpoint1";
let endpoint = Endpoint::try_from(input).expect("Valid endpoint should parse successfully");
assert_eq!(endpoint.namespace, "namespace1");
assert_eq!(endpoint.component, "component1");
assert_eq!(endpoint.name, "endpoint1");
}
#[test]
fn test_valid_endpoint_from_str() {
let input = "namespace2/component2/endpoint2";
let endpoint = Endpoint::from_str(input).expect("Valid endpoint should parse successfully");
assert_eq!(endpoint.namespace, "namespace2");
assert_eq!(endpoint.component, "component2");
assert_eq!(endpoint.name, "endpoint2");
}
#[test]
fn test_valid_endpoint_parse() {
let input = "namespace3/component3/endpoint3";
let endpoint: Endpoint = input
.parse()
.expect("Valid endpoint should parse successfully");
assert_eq!(endpoint.namespace, "namespace3");
assert_eq!(endpoint.component, "component3");
assert_eq!(endpoint.name, "endpoint3");
}
#[test]
fn test_invalid_endpoint_try_from() {
let input = "invalid_endpoint_format";
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
}
#[test]
fn test_invalid_endpoint_from_str() {
let input = "onlyhas/two";
let result = Endpoint::from_str(input);
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
}
#[test]
fn test_invalid_endpoint_parse() {
let input = "too/many/segments/in/url";
let result: Result<Endpoint, _> = input.parse();
assert!(result.is_err(), "Parsing should fail for an invalid format");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
}
#[test]
fn test_empty_endpoint_string() {
let input = "";
let result = Endpoint::try_from(input);
assert!(result.is_err(), "Parsing should fail for an empty string");
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
}
#[test]
fn test_whitespace_endpoint_string() {
let input = " ";
let result = Endpoint::try_from(input);
assert!(
result.is_err(),
"Parsing should fail for a whitespace string"
);
assert_eq!(
result.unwrap_err().to_string(),
"An endpoint URL must have the format: namespace/component/endpoint"
);
}
#[test]
fn test_leading_trailing_slashes() {
let input = "/namespace/component/endpoint/";
let result = Endpoint::try_from(input);
assert!(
result.is_err(),
"Parsing should fail for leading/trailing slashes"
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment