Unverified Commit 78826932 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: use atomic transactions when creating etcd kv (#2044)

parent e5a8628f
......@@ -20,7 +20,7 @@ use std::time::Duration;
use crate::{slug::Slug, transports::etcd::Client};
use async_stream::stream;
use async_trait::async_trait;
use etcd_client::{EventType, PutOptions, WatchOptions};
use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
......@@ -158,31 +158,44 @@ impl EtcdBucket {
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}");
// Does it already exist? For 'create' it shouldn't.
let kvs = self
.client
.kv_get(k.clone(), None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
if !kvs.is_empty() {
let version = kvs.first().unwrap().version();
return Ok(StorageOutcome::Exists(version as u64));
}
// Use atomic transaction to check and create in one operation
let put_options = PutOptions::new();
// Write it
let mut put_resp = self
// Build transaction that creates key only if it doesn't exist
let txn = Txn::new()
.when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // Atomic check
.and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // Only if check passes
.or_else(vec![
TxnOp::get(k.as_str(), None), // Key exists, get its info
]);
// Execute the transaction
let result = self
.client
.kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
.etcd_client()
.kv_client()
.txn(txn)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
// Check if we overwrite something
if put_resp.take_prev_key().is_some() {
// Key created between our get and put
return Err(StorageError::Retry);
if result.succeeded() {
// Key was created successfully
return Ok(StorageOutcome::Created(1)); // version of new key is always 1
}
// version of a new key is always 1
Ok(StorageOutcome::Created(1))
// Key already existed, get its version
if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
result.op_responses().into_iter().next()
{
if let Some(kv) = get_resp.kvs().first() {
let version = kv.version() as u64;
return Ok(StorageOutcome::Exists(version));
}
}
// Shouldn't happen, but handle edge case
Err(StorageError::EtcdError(
"Unexpected transaction response".to_string(),
))
}
async fn update(
......@@ -241,3 +254,152 @@ fn make_key(bucket_name: &str, key: &str) -> String {
]
.join("/")
}
#[cfg(feature = "integration")]
#[cfg(test)]
mod concurrent_create_tests {
    //! Integration test for the atomic (transactional) etcd `create` path:
    //! many workers race to insert the same key and exactly one must win.
    //!
    //! Only compiled with the `integration` feature; presumably needs a
    //! reachable etcd configured through the runtime settings — verify against
    //! the CI setup.

    use super::*;
    use crate::{distributed::DistributedConfig, DistributedRuntime, Runtime};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    /// Bucket shared by every worker in the race.
    const TEST_BUCKET: &str = "test_concurrent_bucket";
    /// Number of tasks that try to create the same key at once.
    const NUM_WORKERS: usize = 10;

    /// Entry point: builds the runtime, then drives the async test body on its
    /// primary executor. Panics (failing the test) on any setup or test error.
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);
        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    /// Races `NUM_WORKERS` inserts of one fresh key and checks that exactly one
    /// reports `Created` while all the others report `Exists`.
    ///
    /// Returns `Err` if a storage call made by the test itself (bucket
    /// creation, final read, cleanup) fails; assertion failures panic.
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStorage::new(etcd_client);

        // Handle used only for the final read-back and cleanup.
        let bucket = storage.get_or_create_bucket(TEST_BUCKET, None).await?;

        // Every worker gets its OWN bucket handle. Sharing one handle behind a
        // mutex would keep the lock held across the whole `insert(..).await`,
        // serializing the requests so that they never actually race at etcd.
        // NOTE(review): assumes repeated `get_or_create_bucket` calls return
        // independent handles onto the same named bucket — confirm.
        // All handles are acquired before any task is spawned so that an early
        // `?` return cannot strand tasks waiting on the barrier.
        let mut worker_buckets = Vec::with_capacity(NUM_WORKERS);
        for _ in 0..NUM_WORKERS {
            worker_buckets.push(storage.get_or_create_bucket(TEST_BUCKET, None).await?);
        }

        let barrier = Arc::new(Barrier::new(NUM_WORKERS));

        // A unique key per run keeps repeated or parallel runs independent.
        let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4());
        let test_value = "test_value";

        // Spawn the workers; each returns its raw outcome so the tallying can
        // happen after the join without any shared counters.
        let mut handles = Vec::with_capacity(NUM_WORKERS);
        for (worker_id, worker_bucket) in worker_buckets.into_iter().enumerate() {
            let barrier = Arc::clone(&barrier);
            let key = test_key.clone();
            let value = format!("{}_from_worker_{}", test_value, worker_id);

            handles.push(tokio::spawn(async move {
                // Release all workers at the same instant.
                barrier.wait().await;

                let outcome = worker_bucket.insert(key, value, 0).await;
                match &outcome {
                    Ok(StorageOutcome::Created(version)) => println!(
                        "Worker {} successfully created key with version {}",
                        worker_id, version
                    ),
                    Ok(StorageOutcome::Exists(version)) => println!(
                        "Worker {} found key already exists with version {}",
                        worker_id, version
                    ),
                    Err(e) => println!("Worker {} got error: {:?}", worker_id, e),
                }
                outcome
            }));
        }

        // Join every worker and tally the outcomes.
        let mut created = 0usize;
        let mut existed = 0usize;
        let mut failures = Vec::new();
        for handle in handles {
            match handle.await.expect("worker task panicked") {
                Ok(StorageOutcome::Created(_)) => created += 1,
                Ok(StorageOutcome::Exists(_)) => existed += 1,
                Err(e) => failures.push(e),
            }
        }

        // Read the key back and clean up BEFORE asserting, so a failed
        // assertion does not leave the test key behind in etcd.
        let stored_value = bucket.get(&test_key).await;
        let cleanup = bucket.delete(&test_key).await;

        println!("Final counts - Created: {}, Exists: {}", created, existed);

        // CRITICAL ASSERTIONS:
        // 1. No worker may fail outright.
        assert!(
            failures.is_empty(),
            "All workers should complete successfully, got errors: {failures:?}"
        );

        // 2. Exactly ONE worker should have successfully created the key.
        assert_eq!(created, 1, "Exactly one worker should create the key");

        // 3. All other workers should have gotten an "Exists" response.
        assert_eq!(
            existed,
            NUM_WORKERS - 1,
            "All other workers should see key exists"
        );

        // 4. The key must actually have been present in etcd.
        let stored_bytes = stored_value?.expect("Key should exist in etcd");

        // 5. The stored value should be from one of the workers.
        let stored_str = String::from_utf8(stored_bytes.to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Surface a cleanup failure only after the real checks have passed.
        cleanup?;

        Ok(())
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment