Unverified Commit a1a10365 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Split PushRouter from Client (#817)

In a distributed system we don't know if the remote workers need pre-processing done ingress-side or not. Previously Client required us to decide this before discovering the remote endpoints, which was fine because pre-processing was worker-side.

As part of moving pre-processing back to ingress-side we need to split this into two steps:
- Client discovers the endpoints, and (later PR) will fetch their Model Deployment Card.
- PushRouter will use the Model Deployment Card to decide if they need pre-processing or not, which affects the types of the generic parameters.

Part of #743
parent 97bf8184
...@@ -68,7 +68,7 @@ mod namespace; ...@@ -68,7 +68,7 @@ mod namespace;
mod registry; mod registry;
pub mod service; pub mod service;
pub use client::{Client, RouterMode}; pub use client::{Client, EndpointSource};
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -96,6 +96,12 @@ pub struct ComponentEndpointInfo { ...@@ -96,6 +96,12 @@ pub struct ComponentEndpointInfo {
pub transport: TransportType, pub transport: TransportType,
} }
impl ComponentEndpointInfo {
pub fn id(&self) -> i64 {
self.lease_id
}
}
/// A [Component] a discoverable entity in the distributed runtime. /// A [Component] a discoverable entity in the distributed runtime.
/// You can host [Endpoint] on a [Component] by first creating /// You can host [Endpoint] on a [Component] by first creating
/// a [Service] then adding one or more [Endpoint] to the [Service]. /// a [Service] then adding one or more [Endpoint] to the [Service].
...@@ -160,6 +166,10 @@ impl Component { ...@@ -160,6 +166,10 @@ impl Component {
&self.namespace &self.namespace
} }
pub fn name(&self) -> String {
self.name.clone()
}
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint { pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
Endpoint { Endpoint {
component: self.clone(), component: self.clone(),
...@@ -272,11 +282,7 @@ impl Endpoint { ...@@ -272,11 +282,7 @@ impl Endpoint {
) )
} }
pub async fn client<Req, Resp>(&self) -> Result<client::Client<Req, Resp>> pub async fn client(&self) -> Result<client::Client> {
where
Req: Serialize + Send + Sync + 'static,
Resp: for<'de> Deserialize<'de> + Send + Sync + 'static,
{
if self.is_static { if self.is_static {
client::Client::new_static(self.clone()).await client::Client::new_static(self.clone()).await
} else { } else {
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
// limitations under the License. // limitations under the License.
use crate::pipeline::{ use crate::pipeline::{
network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter}, AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
AsyncEngine, Data, ManyOut, SingleIn, SingleIn,
}; };
use rand::Rng; use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -25,7 +25,7 @@ use std::sync::{ ...@@ -25,7 +25,7 @@ use std::sync::{
}; };
use tokio::{net::unix::pipe::Receiver, sync::Mutex}; use tokio::{net::unix::pipe::Receiver, sync::Mutex};
use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error}; use crate::{pipeline::async_trait, transports::etcd::WatchEvent};
use super::*; use super::*;
...@@ -48,46 +48,26 @@ enum EndpointEvent { ...@@ -48,46 +48,26 @@ enum EndpointEvent {
Delete(String), Delete(String),
} }
#[derive(Default, Debug, Clone, Copy)] #[derive(Clone, Debug)]
pub enum RouterMode { pub struct Client {
#[default] // This is me
Random, pub endpoint: Endpoint,
RoundRobin, // These are the remotes I know about
//KV, pub endpoints: EndpointSource,
//
// Always and only go to the given endpoint ID.
// TODO: Is this useful?
Direct(i64),
}
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
endpoint: Endpoint,
router: PushRouter<T, U>,
counter: Arc<AtomicU64>,
endpoints: EndpointSource,
router_mode: RouterMode,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum EndpointSource { pub enum EndpointSource {
Static, Static,
Dynamic(tokio::sync::watch::Receiver<Vec<i64>>), Dynamic(tokio::sync::watch::Receiver<Vec<ComponentEndpointInfo>>),
} }
impl<T, U> Client<T, U> impl Client {
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
// Client will only talk to a single static endpoint // Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client { Ok(Client {
router: router(&endpoint).await?,
endpoint, endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Static, endpoints: EndpointSource::Static,
router_mode: Default::default(),
}) })
} }
...@@ -136,7 +116,7 @@ where ...@@ -136,7 +116,7 @@ where
let key = String::from_utf8(kv.key().to_vec()); let key = String::from_utf8(kv.key().to_vec());
let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value()); let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
if let (Ok(key), Ok(val)) = (key, val) { if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val.lease_id); map.insert(key.clone(), val);
} else { } else {
tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix); tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
break; break;
...@@ -153,9 +133,9 @@ where ...@@ -153,9 +133,9 @@ where
} }
} }
let endpoint_ids: Vec<i64> = map.values().cloned().collect(); let endpoints: Vec<ComponentEndpointInfo> = map.values().cloned().collect();
if watch_tx.send(endpoint_ids).is_err() { if watch_tx.send(endpoints).is_err() {
tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix); tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
break; break;
} }
...@@ -167,11 +147,8 @@ where ...@@ -167,11 +147,8 @@ where
}); });
Ok(Client { Ok(Client {
router: router(&endpoint).await?,
endpoint, endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Dynamic(watch_rx), endpoints: EndpointSource::Dynamic(watch_rx),
router_mode: Default::default(),
}) })
} }
...@@ -185,135 +162,36 @@ where ...@@ -185,135 +162,36 @@ where
self.endpoint.etcd_path() self.endpoint.etcd_path()
} }
pub fn endpoint_ids(&self) -> Vec<i64> { pub fn endpoints(&self) -> Vec<ComponentEndpointInfo> {
match &self.endpoints { match &self.endpoints {
EndpointSource::Static => vec![0], EndpointSource::Static => vec![],
EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(), EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
} }
} }
pub fn set_router_mode(&mut self, mode: RouterMode) { pub fn endpoint_ids(&self) -> Vec<i64> {
self.router_mode = mode self.endpoints().into_iter().map(|ep| ep.id()).collect()
} }
/// Wait for at least one [`Endpoint`] to be available /// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<()> { pub async fn wait_for_endpoints(&self) -> Result<Vec<ComponentEndpointInfo>> {
let mut endpoints: Vec<ComponentEndpointInfo> = vec![];
if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() { if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
// wait for there to be 1 or more endpoints // wait for there to be 1 or more endpoints
loop { loop {
if rx.borrow_and_update().is_empty() { endpoints = rx.borrow_and_update().to_vec();
if endpoints.is_empty() {
rx.changed().await?; rx.changed().await?;
} else { } else {
break; break;
} }
} }
} }
Ok(()) Ok(endpoints)
} }
/// Is this component know at startup and not discovered via etcd? /// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool { pub fn is_static(&self) -> bool {
matches!(self.endpoints, EndpointSource::Static) matches!(self.endpoints, EndpointSource::Static)
} }
/// Issue a request to the next available endpoint in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let counter = self.counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = {
let endpoints = self.endpoint_ids();
let count = endpoints.len();
if count == 0 {
return Err(error!(
"no endpoints found for endpoint {:?}",
self.endpoint.etcd_path()
));
}
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let endpoint_id = {
let endpoints = self.endpoint_ids();
let count = endpoints.len();
if count == 0 {
return Err(error!(
"no endpoints found for endpoint {:?}",
self.endpoint.etcd_path()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
endpoints[offset as usize]
};
tracing::trace!("random router selected {endpoint_id}");
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
/// Issue a request to a specific endpoint
pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
let found = {
let endpoints = self.endpoint_ids();
endpoints.contains(&endpoint_id)
};
if !found {
return Err(error!(
"endpoint_id={} not found for endpoint {:?}",
endpoint_id,
self.endpoint.etcd_path()
));
}
let subject = self.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.router.generate(request).await
}
pub async fn r#static(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let subject = self.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.router.generate(request).await
}
}
async fn router(endpoint: &Endpoint) -> Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::new(
endpoint.component.drt.nats_client.client().clone(),
endpoint.component.drt.tcp_server().await?,
)
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
}
}
} }
...@@ -26,6 +26,8 @@ pub use nodes::{ ...@@ -26,6 +26,8 @@ pub use nodes::{
pub mod context; pub mod context;
pub mod error; pub mod error;
pub mod network; pub mod network;
pub use network::egress::addressed_router::{AddressedPushRouter, AddressedRequest};
pub use network::egress::push_router::{PushRouter, RouterMode};
pub mod registry; pub mod registry;
pub use crate::engine::{ pub use crate::engine::{
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod push; pub mod addressed_router;
pub mod push_router;
use super::*; use super::*;
...@@ -41,9 +41,6 @@ struct RequestControlMessage { ...@@ -41,9 +41,6 @@ struct RequestControlMessage {
connection_info: ConnectionInfo, connection_info: ConnectionInfo,
} }
pub type PushRouter<In, Out> =
Arc<dyn AsyncEngine<SingleIn<AddressedRequest<In>>, ManyOut<Out>, Error>>;
pub struct AddressedRequest<T> { pub struct AddressedRequest<T> {
request: T, request: T,
address: String, address: String,
...@@ -111,7 +108,7 @@ where ...@@ -111,7 +108,7 @@ where
} }
}; };
// separate out the the connection info and the stream provider from the registered stream // separate out the connection info and the stream provider from the registered stream
let (connection_info, response_stream_provider) = pending_response_stream.into_parts(); let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
// package up the connection info as part of the "header" component of the two part message // package up the connection info as part of the "header" component of the two part message
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::{
marker::PhantomData,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use crate::{
component::{Client, Endpoint, EndpointSource},
engine::{AsyncEngine, Data},
pipeline::{AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn},
traits::DistributedRuntimeProvider,
};
#[derive(Clone)]
pub struct PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
// TODO: This shouldn't be pub, but lib/bindings/python/rust/lib.rs exposes it.
/// The Client is how we gather remote endpoint information from etcd.
pub client: Client,
/// How we choose which endpoint to send traffic to.
router_mode: RouterMode,
/// Number of round robin requests handled. Used to decide which server is next.
round_robin_counter: Arc<AtomicU64>,
/// The next step in the chain. PushRouter (this object) picks an endpoint,
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time.
_phantom: PhantomData<(T, U)>,
}
#[derive(Default, Debug, Clone, Copy)]
pub enum RouterMode {
#[default]
Random,
RoundRobin,
//KV,
//
// Always and only go to the given endpoint ID. Used by Python bindings.
Direct(i64),
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::new(
endpoint.drt().nats_client.client().clone(),
endpoint.drt().tcp_server().await?,
)
}
impl<T, U> PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client,
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
_phantom: PhantomData,
})
}
/// Issue a request to the next available endpoint in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
));
}
let offset = counter % count as u64;
endpoints[offset as usize].id()
};
tracing::trace!("round robin router selected {endpoint_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
/// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let endpoint_id = {
let endpoints = self.client.endpoints();
let count = endpoints.len();
if count == 0 {
return Err(anyhow::anyhow!(
"no endpoints found for endpoint {:?}",
self.client.endpoint.etcd_path()
));
}
let counter = rand::rng().random::<u64>();
let offset = counter % count as u64;
endpoints[offset as usize].id()
};
tracing::trace!("random router selected {endpoint_id}");
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
/// Issue a request to a specific endpoint
pub async fn direct(
&self,
request: SingleIn<T>,
endpoint_id: i64,
) -> anyhow::Result<ManyOut<U>> {
let found = {
let endpoints = self.client.endpoints();
endpoints.iter().any(|ep| ep.id() == endpoint_id)
};
if !found {
return Err(anyhow::anyhow!(
"endpoint_id={} not found for endpoint {:?}",
endpoint_id,
self.client.endpoint.etcd_path()
));
}
let subject = self.client.endpoint.subject_to(endpoint_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
self.addressed.generate(request).await
}
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.addressed.generate(request).await
}
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for PushRouter<T, U>
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>,
{
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
match &self.client.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => match self.router_mode {
RouterMode::Random => self.random(request).await,
RouterMode::RoundRobin => self.round_robin(request).await,
RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
},
}
}
}
...@@ -147,7 +147,7 @@ impl TcpStreamServer { ...@@ -147,7 +147,7 @@ impl TcpStreamServer {
PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e)) PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
})?; })?;
tracing::info!("tcp transport service on {}:{}", local_ip, local_port); tracing::debug!("tcp transport service on {local_ip}:{local_port}");
Ok(Arc::new(Self { Ok(Arc::new(Self {
local_ip, local_ip,
......
...@@ -33,7 +33,7 @@ impl Slug { ...@@ -33,7 +33,7 @@ impl Slug {
/// Create [`Slug`] from a string. /// Create [`Slug`] from a string.
pub fn from_string(s: impl AsRef<str>) -> Slug { pub fn from_string(s: impl AsRef<str>) -> Slug {
Slug::slugify_unique(s.as_ref()) Slug::slugify(s.as_ref())
} }
/// Turn the string into a valid slug, replacing any not-web-or-nats-safe characters with '-' /// Turn the string into a valid slug, replacing any not-web-or-nats-safe characters with '-'
...@@ -54,7 +54,7 @@ impl Slug { ...@@ -54,7 +54,7 @@ impl Slug {
} }
/// Like slugify but also add a four byte hash on the end, in case two different strings slug /// Like slugify but also add a four byte hash on the end, in case two different strings slug
/// to the same thing. /// to the same thing (e.g. because of case differences).
pub fn slugify_unique(s: &str) -> Slug { pub fn slugify_unique(s: &str) -> Slug {
let out = s let out = s
.to_lowercase() .to_lowercase()
......
...@@ -25,7 +25,8 @@ use tokio::sync::{mpsc, RwLock}; ...@@ -25,7 +25,8 @@ use tokio::sync::{mpsc, RwLock};
use validator::Validate; use validator::Validate;
use etcd_client::{ use etcd_client::{
Compare, CompareOp, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse, WatchOptions, Watcher, Compare, CompareOp, DeleteOptions, GetOptions, PutOptions, PutResponse, Txn, TxnOp,
TxnOpResponse, WatchOptions, Watcher,
}; };
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient}; pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
...@@ -172,13 +173,14 @@ impl Client { ...@@ -172,13 +173,14 @@ impl Client {
value: Vec<u8>, value: Vec<u8>,
lease_id: Option<i64>, lease_id: Option<i64>,
) -> Result<()> { ) -> Result<()> {
let put_options = lease_id.map(|id| PutOptions::new().with_lease(id)); let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
// Build the transaction // Build the transaction
let txn = Txn::new() let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist .when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
.and_then(vec![ .and_then(vec![
TxnOp::put(key.as_str(), value, put_options), // Create the object TxnOp::put(key.as_str(), value, Some(put_options)), // Create the object
]); ]);
// Execute the transaction // Execute the transaction
...@@ -201,14 +203,15 @@ impl Client { ...@@ -201,14 +203,15 @@ impl Client {
value: Vec<u8>, value: Vec<u8>,
lease_id: Option<i64>, lease_id: Option<i64>,
) -> Result<()> { ) -> Result<()> {
let put_options = lease_id.map(|id| PutOptions::new().with_lease(id)); let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
// Build the transaction that either creates the key if it doesn't exist, // Build the transaction that either creates the key if it doesn't exist,
// or validates the existing value matches what we expect // or validates the existing value matches what we expect
let txn = Txn::new() let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Key doesn't exist .when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Key doesn't exist
.and_then(vec![ .and_then(vec![
TxnOp::put(key.as_str(), value.clone(), put_options), // Create it TxnOp::put(key.as_str(), value.clone(), Some(put_options)), // Create it
]) ])
.or_else(vec![ .or_else(vec![
// If key exists but values don't match, this will fail the transaction // If key exists but values don't match, this will fail the transaction
...@@ -245,19 +248,54 @@ impl Client { ...@@ -245,19 +248,54 @@ impl Client {
value: impl AsRef<[u8]>, value: impl AsRef<[u8]>,
lease_id: Option<i64>, lease_id: Option<i64>,
) -> Result<()> { ) -> Result<()> {
let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
let _ = self let _ = self
.client .client
.kv_client() .kv_client()
.put( .put(key.as_ref(), value.as_ref(), Some(put_options))
key.as_ref(),
value.as_ref(),
lease_id.map(|id| PutOptions::new().with_lease(id)),
)
.await?; .await?;
Ok(()) Ok(())
} }
pub async fn kv_put_with_options(
&self,
key: impl AsRef<str>,
value: impl AsRef<[u8]>,
options: Option<PutOptions>,
) -> Result<PutResponse> {
let options = options
.unwrap_or_default()
.with_lease(self.primary_lease().id());
self.client
.kv_client()
.put(key.as_ref(), value.as_ref(), Some(options))
.await
.map_err(|err| err.into())
}
pub async fn kv_get(
&self,
key: impl Into<Vec<u8>>,
options: Option<GetOptions>,
) -> Result<Vec<KeyValue>> {
let mut get_response = self.client.kv_client().get(key, options).await?;
Ok(get_response.take_kvs())
}
pub async fn kv_delete(
&self,
key: impl Into<Vec<u8>>,
options: Option<DeleteOptions>,
) -> Result<i64> {
self.client
.kv_client()
.delete(key, options)
.await
.map(|del_response| del_response.deleted())
.map_err(|err| err.into())
}
pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> { pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
let mut get_response = self let mut get_response = self
.client .client
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment