"vscode:/vscode.git/clone" did not exist on "e252f82dced4a5af6acaa146af4ddfebc0b851c9"
Unverified commit 16389141, authored by jthomson04 and committed by GitHub
Browse files

feat: Improvements to Leader-Worker barrier (#1498)

parent 1065ff1a
...@@ -108,22 +108,7 @@ async fn create_barrier_key<T: Serialize>( ...@@ -108,22 +108,7 @@ async fn create_barrier_key<T: Serialize>(
client client
.kv_create(key, serialized_data, lease_id) .kv_create(key, serialized_data, lease_id)
.await .await
.map_err(|_| LeaderWorkerBarrierError::BarrierIdNotUnique)?; .map_err(|_| LeaderWorkerBarrierError::IdNotUnique)?;
Ok(())
}
/// Creates a worker-specific key in etcd
async fn create_worker_key(
client: &Client,
key: &str,
lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
// TODO: Same as above. This can fail for many reasons.
client
.kv_create(key.to_owned(), serde_json::to_vec(&()).unwrap(), lease_id)
.await
.map_err(|_| LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)?;
Ok(()) Ok(())
} }
...@@ -140,8 +125,7 @@ async fn wait_for_signal<T: DeserializeOwned>( ...@@ -140,8 +125,7 @@ async fn wait_for_signal<T: DeserializeOwned>(
#[derive(Debug)] #[derive(Debug)]
pub enum LeaderWorkerBarrierError { pub enum LeaderWorkerBarrierError {
EtcdClientNotFound, EtcdClientNotFound,
BarrierIdNotUnique, IdNotUnique,
BarrierWorkerIdNotUnique,
EtcdError(anyhow::Error), EtcdError(anyhow::Error),
SerdeError(serde_json::Error), SerdeError(serde_json::Error),
Timeout, Timeout,
...@@ -150,14 +134,16 @@ pub enum LeaderWorkerBarrierError { ...@@ -150,14 +134,16 @@ pub enum LeaderWorkerBarrierError {
} }
/// A barrier for a leader to wait for a specific number of workers to join. /// A barrier for a leader to wait for a specific number of workers to join.
pub struct LeaderBarrier<T> { pub struct LeaderBarrier<LeaderData, WorkerData> {
barrier_id: String, barrier_id: String,
num_workers: usize, num_workers: usize,
timeout: Option<Duration>, timeout: Option<Duration>,
marker: PhantomData<T>, marker: PhantomData<(LeaderData, WorkerData)>,
} }
impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> { impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
LeaderBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self { pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self {
Self { Self {
barrier_id, barrier_id,
...@@ -174,8 +160,8 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> { ...@@ -174,8 +160,8 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
pub async fn sync( pub async fn sync(
self, self,
rt: &DistributedRuntime, rt: &DistributedRuntime,
data: &T, data: &LeaderData,
) -> anyhow::Result<(), LeaderWorkerBarrierError> { ) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let etcd_client = rt let etcd_client = rt
.etcd_client() .etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?; .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
...@@ -193,13 +179,17 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> { ...@@ -193,13 +179,17 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
self.signal_completion(&etcd_client, &worker_result, lease_id) self.signal_completion(&etcd_client, &worker_result, lease_id)
.await?; .await?;
worker_result.map(|_| ()) worker_result.map(|r| {
r.into_iter()
.map(|(k, v)| (k.split("/").last().unwrap().to_string(), v))
.collect()
})
} }
async fn publish_barrier_data( async fn publish_barrier_data(
&self, &self,
client: &Client, client: &Client,
data: &T, data: &LeaderData,
lease_id: i64, lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> { ) -> Result<(), LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_DATA); let key = barrier_key(&self.barrier_id, BARRIER_DATA);
...@@ -209,21 +199,24 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> { ...@@ -209,21 +199,24 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
async fn wait_for_workers( async fn wait_for_workers(
&self, &self,
client: &Client, client: &Client,
) -> Result<HashSet<String>, LeaderWorkerBarrierError> { ) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_WORKER); let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
let workers = wait_for_key_count::<()>(client, key, self.num_workers, self.timeout).await?; let workers = wait_for_key_count(client, key, self.num_workers, self.timeout).await?;
Ok(workers.into_keys().collect()) Ok(workers)
} }
async fn signal_completion( async fn signal_completion(
&self, &self,
client: &Client, client: &Client,
worker_result: &Result<HashSet<String>, LeaderWorkerBarrierError>, worker_result: &Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError>,
lease_id: i64, lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> { ) -> Result<(), LeaderWorkerBarrierError> {
if let Ok(worker_result) = worker_result { if let Ok(worker_result) = worker_result {
let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE); let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
create_barrier_key(client, key, worker_result, Some(lease_id)).await?;
let workers = worker_result.keys().collect::<HashSet<_>>();
create_barrier_key(client, key, workers, Some(lease_id)).await?;
} else { } else {
let key = barrier_key(&self.barrier_id, BARRIER_ABORT); let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?; create_barrier_key(client, key, (), Some(lease_id)).await?;
...@@ -234,13 +227,15 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> { ...@@ -234,13 +227,15 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
} }
// A barrier to synchronize a worker with a leader. // A barrier to synchronize a worker with a leader.
pub struct WorkerBarrier<T> { pub struct WorkerBarrier<LeaderData, WorkerData> {
barrier_id: String, barrier_id: String,
worker_id: String, worker_id: String,
marker: PhantomData<T>, marker: PhantomData<(LeaderData, WorkerData)>,
} }
impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> { impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
WorkerBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, worker_id: String) -> Self { pub fn new(barrier_id: String, worker_id: String) -> Self {
Self { Self {
barrier_id, barrier_id,
...@@ -259,7 +254,8 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> { ...@@ -259,7 +254,8 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
pub async fn sync( pub async fn sync(
self, self,
rt: &DistributedRuntime, rt: &DistributedRuntime,
) -> anyhow::Result<T, LeaderWorkerBarrierError> { data: &WorkerData,
) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> {
let etcd_client = rt let etcd_client = rt
.etcd_client() .etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?; .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
...@@ -270,7 +266,7 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> { ...@@ -270,7 +266,7 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
let barrier_data = self.get_barrier_data(&etcd_client).await?; let barrier_data = self.get_barrier_data(&etcd_client).await?;
// Register as a worker // Register as a worker
let worker_key = self.register_worker(&etcd_client, lease_id).await?; let worker_key = self.register_worker(&etcd_client, data, lease_id).await?;
// Wait for completion or abort signal // Wait for completion or abort signal
self.wait_for_completion(&etcd_client, worker_key).await?; self.wait_for_completion(&etcd_client, worker_key).await?;
...@@ -278,12 +274,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> { ...@@ -278,12 +274,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
Ok(barrier_data) Ok(barrier_data)
} }
async fn get_barrier_data(&self, client: &Client) -> Result<T, LeaderWorkerBarrierError> { async fn get_barrier_data(
&self,
client: &Client,
) -> Result<LeaderData, LeaderWorkerBarrierError> {
let data_key = barrier_key(&self.barrier_id, BARRIER_DATA); let data_key = barrier_key(&self.barrier_id, BARRIER_DATA);
let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT); let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);
tokio::select! { tokio::select! {
result = wait_for_key_count::<T>(client, data_key, 1, None) => { result = wait_for_key_count::<LeaderData>(client, data_key, 1, None) => {
result?.into_values().next() result?.into_values().next()
.ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found"))) .ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
} }
...@@ -296,15 +295,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> { ...@@ -296,15 +295,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
async fn register_worker( async fn register_worker(
&self, &self,
client: &Client, client: &Client,
data: &WorkerData,
lease_id: i64, lease_id: i64,
) -> Result<String, LeaderWorkerBarrierError> { ) -> Result<String, LeaderWorkerBarrierError> {
let key = barrier_key( let key = barrier_key(
&self.barrier_id, &self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id), &format!("{}/{}", BARRIER_WORKER, self.worker_id),
); );
create_worker_key(client, &key, Some(lease_id)) create_barrier_key(client, key.clone(), data, Some(lease_id)).await?;
.await Ok(key)
.map(|_| key)
} }
async fn wait_for_completion( async fn wait_for_completion(
...@@ -354,15 +353,15 @@ mod tests { ...@@ -354,15 +353,15 @@ mod tests {
assert!(drt.etcd_client().is_none()); assert!(drt.etcd_client().is_none());
let barrier = LeaderBarrier::new("test".to_string(), 2, None); let barrier = LeaderBarrier::<String, String>::new("test".to_string(), 2, None);
let worker = WorkerBarrier::<()>::new("test".to_string(), "worker".to_string()); let worker = WorkerBarrier::<String, String>::new("test".to_string(), "worker".to_string());
assert!(matches!( assert!(matches!(
barrier.sync(&drt, &"test".to_string()).await, barrier.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound) Err(LeaderWorkerBarrierError::EtcdClientNotFound)
)); ));
assert!(matches!( assert!(matches!(
worker.sync(&drt).await, worker.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound) Err(LeaderWorkerBarrierError::EtcdClientNotFound)
)); ));
} }
...@@ -374,19 +373,24 @@ mod tests { ...@@ -374,19 +373,24 @@ mod tests {
let id = unique_id(); let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None); let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string()); let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?; let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);
Ok(()) Ok(())
}); });
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
let res = worker.sync(&drt).await?; let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string()); assert_eq!(res, "test_data".to_string());
Ok(()) Ok(())
...@@ -405,15 +409,20 @@ mod tests { ...@@ -405,15 +409,20 @@ mod tests {
let id = unique_id(); let id = unique_id();
let leader1 = LeaderBarrier::new(id.clone(), 1, None); let leader1 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let leader2 = LeaderBarrier::new(id.clone(), 1, None); let leader2 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string()); let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
leader1.sync(&drt_clone, &"test_data".to_string()).await?; let worker_data = leader1.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);
// Now, try to sync leader 2. // Now, try to sync leader 2.
let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await; let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await;
...@@ -421,7 +430,7 @@ mod tests { ...@@ -421,7 +430,7 @@ mod tests {
// Leader 2 should fail because the barrier ID is the same as leader 1. // Leader 2 should fail because the barrier ID is the same as leader 1.
assert!(matches!( assert!(matches!(
leader2_res, leader2_res,
Err(LeaderWorkerBarrierError::BarrierIdNotUnique) Err(LeaderWorkerBarrierError::IdNotUnique)
)); ));
Ok(()) Ok(())
...@@ -429,7 +438,7 @@ mod tests { ...@@ -429,7 +438,7 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
let res = worker.sync(&drt).await?; let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string()); assert_eq!(res, "test_data".to_string());
Ok(()) Ok(())
...@@ -448,26 +457,33 @@ mod tests { ...@@ -448,26 +457,33 @@ mod tests {
let id = unique_id(); let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None); let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string()); let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let worker2 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string()); let worker2 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?; let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker_1".to_string()
);
Ok(()) Ok(())
}); });
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
worker1.sync(&drt).await?; let leader_data = worker1.sync(&drt, &"test_worker_1".to_string()).await?;
assert_eq!(leader_data, "test_data".to_string());
let worker2_res = worker2.sync(&drt).await; let worker2_res = worker2.sync(&drt, &"test_worker_2".to_string()).await;
assert!(matches!( assert!(matches!(
worker2_res, worker2_res,
Err(LeaderWorkerBarrierError::BarrierWorkerIdNotUnique) Err(LeaderWorkerBarrierError::IdNotUnique)
)); ));
Ok(()) Ok(())
...@@ -486,9 +502,9 @@ mod tests { ...@@ -486,9 +502,9 @@ mod tests {
let id = unique_id(); let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 2, Some(Duration::from_millis(100))); let leader = LeaderBarrier::<(), ()>::new(id.clone(), 2, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string()); let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string()); let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
...@@ -502,7 +518,7 @@ mod tests { ...@@ -502,7 +518,7 @@ mod tests {
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
let res = worker1.sync(&drt_clone).await; let res = worker1.sync(&drt_clone, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted))); assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));
Ok(()) Ok(())
...@@ -511,7 +527,7 @@ mod tests { ...@@ -511,7 +527,7 @@ mod tests {
let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(200)).await; tokio::time::sleep(Duration::from_millis(200)).await;
let res = worker2.sync(&drt).await; let res = worker2.sync(&drt, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted))); assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));
Ok(()) Ok(())
...@@ -533,8 +549,9 @@ mod tests { ...@@ -533,8 +549,9 @@ mod tests {
let id = unique_id(); let id = unique_id();
// Get the leader to send a (), when the worker expects a String. // Get the leader to send a (), when the worker expects a String.
let leader = LeaderBarrier::new(id.clone(), 1, Some(Duration::from_millis(100))); let leader =
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker1".to_string()); LeaderBarrier::<(), String>::new(id.clone(), 1, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker1".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
...@@ -549,7 +566,7 @@ mod tests { ...@@ -549,7 +566,7 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
assert!(matches!( assert!(matches!(
worker1.sync(&drt).await, worker1.sync(&drt, &"test_worker".to_string()).await,
Err(LeaderWorkerBarrierError::SerdeError(_)) Err(LeaderWorkerBarrierError::SerdeError(_))
)); ));
...@@ -569,9 +586,9 @@ mod tests { ...@@ -569,9 +586,9 @@ mod tests {
let id = unique_id(); let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None); let leader = LeaderBarrier::<(), ()>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string()); let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string()); let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
...@@ -583,9 +600,9 @@ mod tests { ...@@ -583,9 +600,9 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> = let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move { tokio::spawn(async move {
let drt_clone = drt.clone(); let drt_clone = drt.clone();
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone).await }); let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone, &()).await });
let worker2_join = tokio::spawn(async move { worker2.sync(&drt).await }); let worker2_join = tokio::spawn(async move { worker2.sync(&drt, &()).await });
let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join); let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment