Commit 16389141 (unverified), authored by jthomson04, committed by GitHub
Browse files

feat: Improvements to Leader-Worker barrier (#1498)

parent 1065ff1a
......@@ -108,22 +108,7 @@ async fn create_barrier_key<T: Serialize>(
client
.kv_create(key, serialized_data, lease_id)
.await
.map_err(|_| LeaderWorkerBarrierError::BarrierIdNotUnique)?;
Ok(())
}
/// Creates a worker-specific key in etcd
///
/// Calls `kv_create` on `client` with an owned copy of `key`, the JSON encoding of the
/// unit value `()` as the stored value, and `lease_id` passed through unchanged.
///
/// # Errors
///
/// Any error returned by `kv_create` is discarded and replaced by a "not unique"
/// variant of `LeaderWorkerBarrierError`, regardless of the underlying cause.
async fn create_worker_key(
client: &Client,
key: &str,
lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
// TODO: Same as above. This can fail for many reasons.
client
// The `unwrap` is on the result of serializing `()`.
.kv_create(key.to_owned(), serde_json::to_vec(&()).unwrap(), lease_id)
.await
// NOTE(review): the two consecutive `.map_err` lines below look like the before/after
// sides of a change shown together — confirm which one applies.
.map_err(|_| LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)?;
.map_err(|_| LeaderWorkerBarrierError::IdNotUnique)?;
Ok(())
}
......@@ -140,8 +125,7 @@ async fn wait_for_signal<T: DeserializeOwned>(
#[derive(Debug)]
pub enum LeaderWorkerBarrierError {
EtcdClientNotFound,
BarrierIdNotUnique,
BarrierWorkerIdNotUnique,
IdNotUnique,
EtcdError(anyhow::Error),
SerdeError(serde_json::Error),
Timeout,
......@@ -150,14 +134,16 @@ pub enum LeaderWorkerBarrierError {
}
/// A barrier for a leader to wait for a specific number of workers to join.
///
/// In the two-parameter form, `LeaderData` is the type of the value the leader passes to
/// `sync`, and `WorkerData` is the type of the per-worker values `sync` returns in a
/// `HashMap<String, WorkerData>`.
pub struct LeaderBarrier<T> {
pub struct LeaderBarrier<LeaderData, WorkerData> {
// Identifier passed to `barrier_key` to build every key this barrier uses.
barrier_id: String,
// Number of worker keys `wait_for_workers` asks `wait_for_key_count` for.
num_workers: usize,
// Optional timeout forwarded to `wait_for_key_count` while waiting for workers.
timeout: Option<Duration>,
// Zero-sized marker: ties the type parameter(s) to the struct without storing a value.
marker: PhantomData<T>,
marker: PhantomData<(LeaderData, WorkerData)>,
}
impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
LeaderBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self {
Self {
barrier_id,
......@@ -174,8 +160,8 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
pub async fn sync(
self,
rt: &DistributedRuntime,
data: &T,
) -> anyhow::Result<(), LeaderWorkerBarrierError> {
data: &LeaderData,
) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let etcd_client = rt
.etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
......@@ -193,13 +179,17 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
self.signal_completion(&etcd_client, &worker_result, lease_id)
.await?;
worker_result.map(|_| ())
worker_result.map(|r| {
r.into_iter()
.map(|(k, v)| (k.split("/").last().unwrap().to_string(), v))
.collect()
})
}
async fn publish_barrier_data(
&self,
client: &Client,
data: &T,
data: &LeaderData,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_DATA);
......@@ -209,21 +199,24 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
/// Waits for the workers of this barrier and returns what `wait_for_key_count` collected.
///
/// The lookup key is built from `barrier_id` and `BARRIER_WORKER`; the requested count is
/// `num_workers` and the wait is bounded by `timeout` (both forwarded as-is).
///
/// # Errors
///
/// Propagates any `LeaderWorkerBarrierError` returned by `wait_for_key_count`.
async fn wait_for_workers(
&self,
client: &Client,
) -> Result<HashSet<String>, LeaderWorkerBarrierError> {
) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
// Variant returning `HashSet<String>`: values are read as `()` and only the keys are kept.
let workers = wait_for_key_count::<()>(client, key, self.num_workers, self.timeout).await?;
Ok(workers.into_keys().collect())
// Variant returning `HashMap<String, WorkerData>`: the map is returned unchanged; the
// value type is inferred from the declared return type.
let workers = wait_for_key_count(client, key, self.num_workers, self.timeout).await?;
Ok(workers)
}
async fn signal_completion(
&self,
client: &Client,
worker_result: &Result<HashSet<String>, LeaderWorkerBarrierError>,
worker_result: &Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError>,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
if let Ok(worker_result) = worker_result {
let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
create_barrier_key(client, key, worker_result, Some(lease_id)).await?;
let workers = worker_result.keys().collect::<HashSet<_>>();
create_barrier_key(client, key, workers, Some(lease_id)).await?;
} else {
let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?;
......@@ -234,13 +227,15 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
}
/// A barrier to synchronize a worker with a leader.
///
/// In the two-parameter form, `LeaderData` is the type `sync` returns to the worker and
/// `WorkerData` is the type of the value the worker passes into `sync`.
pub struct WorkerBarrier<T> {
pub struct WorkerBarrier<LeaderData, WorkerData> {
// Identifier passed to `barrier_key` to build the keys this worker reads and creates.
barrier_id: String,
// Suffix of this worker's registration key (`{BARRIER_WORKER}/{worker_id}`).
worker_id: String,
// Zero-sized marker: ties the type parameter(s) to the struct without storing a value.
marker: PhantomData<T>,
marker: PhantomData<(LeaderData, WorkerData)>,
}
impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
WorkerBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, worker_id: String) -> Self {
Self {
barrier_id,
......@@ -259,7 +254,8 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
pub async fn sync(
self,
rt: &DistributedRuntime,
) -> anyhow::Result<T, LeaderWorkerBarrierError> {
data: &WorkerData,
) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> {
let etcd_client = rt
.etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
......@@ -270,7 +266,7 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
let barrier_data = self.get_barrier_data(&etcd_client).await?;
// Register as a worker
let worker_key = self.register_worker(&etcd_client, lease_id).await?;
let worker_key = self.register_worker(&etcd_client, data, lease_id).await?;
// Wait for completion or abort signal
self.wait_for_completion(&etcd_client, worker_key).await?;
......@@ -278,12 +274,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
Ok(barrier_data)
}
async fn get_barrier_data(&self, client: &Client) -> Result<T, LeaderWorkerBarrierError> {
async fn get_barrier_data(
&self,
client: &Client,
) -> Result<LeaderData, LeaderWorkerBarrierError> {
let data_key = barrier_key(&self.barrier_id, BARRIER_DATA);
let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);
tokio::select! {
result = wait_for_key_count::<T>(client, data_key, 1, None) => {
result = wait_for_key_count::<LeaderData>(client, data_key, 1, None) => {
result?.into_values().next()
.ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
}
......@@ -296,15 +295,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
/// Registers this worker with the barrier by creating its worker key.
///
/// The key is `barrier_key(barrier_id, "{BARRIER_WORKER}/{worker_id}")`, created with
/// `Some(lease_id)`. On success the key string is returned to the caller.
///
/// # Errors
///
/// Returns whatever `LeaderWorkerBarrierError` the key-creation helper produces.
async fn register_worker(
&self,
client: &Client,
data: &WorkerData,
lease_id: i64,
) -> Result<String, LeaderWorkerBarrierError> {
let key = barrier_key(
&self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id),
);
// Variant without a payload: goes through `create_worker_key`, then yields the key.
create_worker_key(client, &key, Some(lease_id))
.await
.map(|_| key)
// Variant with a payload: passes `data` to `create_barrier_key`; the key is cloned
// because the helper takes it by value and the original is returned below.
create_barrier_key(client, key.clone(), data, Some(lease_id)).await?;
Ok(key)
}
async fn wait_for_completion(
......@@ -354,15 +353,15 @@ mod tests {
assert!(drt.etcd_client().is_none());
let barrier = LeaderBarrier::new("test".to_string(), 2, None);
let worker = WorkerBarrier::<()>::new("test".to_string(), "worker".to_string());
let barrier = LeaderBarrier::<String, String>::new("test".to_string(), 2, None);
let worker = WorkerBarrier::<String, String>::new("test".to_string(), "worker".to_string());
assert!(matches!(
barrier.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound)
));
assert!(matches!(
worker.sync(&drt).await,
worker.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound)
));
}
......@@ -374,19 +373,24 @@ mod tests {
let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);
Ok(())
});
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker.sync(&drt).await?;
let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string());
Ok(())
......@@ -405,15 +409,20 @@ mod tests {
let id = unique_id();
let leader1 = LeaderBarrier::new(id.clone(), 1, None);
let leader2 = LeaderBarrier::new(id.clone(), 1, None);
let leader1 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let leader2 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone();
let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader1.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader1.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);
// Now, try to sync leader 2.
let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await;
......@@ -421,7 +430,7 @@ mod tests {
// Leader 2 should fail because the barrier ID is the same as leader 1.
assert!(matches!(
leader2_res,
Err(LeaderWorkerBarrierError::BarrierIdNotUnique)
Err(LeaderWorkerBarrierError::IdNotUnique)
));
Ok(())
......@@ -429,7 +438,7 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker.sync(&drt).await?;
let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string());
Ok(())
......@@ -448,26 +457,33 @@ mod tests {
let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let worker2 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let worker2 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker_1".to_string()
);
Ok(())
});
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
worker1.sync(&drt).await?;
let leader_data = worker1.sync(&drt, &"test_worker_1".to_string()).await?;
assert_eq!(leader_data, "test_data".to_string());
let worker2_res = worker2.sync(&drt).await;
let worker2_res = worker2.sync(&drt, &"test_worker_2".to_string()).await;
assert!(matches!(
worker2_res,
Err(LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)
Err(LeaderWorkerBarrierError::IdNotUnique)
));
Ok(())
......@@ -486,9 +502,9 @@ mod tests {
let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 2, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string());
let leader = LeaderBarrier::<(), ()>::new(id.clone(), 2, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());
let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
......@@ -502,7 +518,7 @@ mod tests {
let drt_clone = drt.clone();
let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker1.sync(&drt_clone).await;
let res = worker1.sync(&drt_clone, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));
Ok(())
......@@ -511,7 +527,7 @@ mod tests {
let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
let res = worker2.sync(&drt).await;
let res = worker2.sync(&drt, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));
Ok(())
......@@ -533,8 +549,9 @@ mod tests {
let id = unique_id();
// Get the leader to send a (), when the worker expects a String.
let leader = LeaderBarrier::new(id.clone(), 1, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker1".to_string());
let leader =
LeaderBarrier::<(), String>::new(id.clone(), 1, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker1".to_string());
let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
......@@ -549,7 +566,7 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
assert!(matches!(
worker1.sync(&drt).await,
worker1.sync(&drt, &"test_worker".to_string()).await,
Err(LeaderWorkerBarrierError::SerdeError(_))
));
......@@ -569,9 +586,9 @@ mod tests {
let id = unique_id();
let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string());
let leader = LeaderBarrier::<(), ()>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());
let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
......@@ -583,9 +600,9 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let drt_clone = drt.clone();
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone).await });
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone, &()).await });
let worker2_join = tokio::spawn(async move { worker2.sync(&drt).await });
let worker2_join = tokio::spawn(async move { worker2.sync(&drt, &()).await });
let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment