Unverified Commit b6852919 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

fix: handle early arrivals (#7126)


Signed-off-by: default avatarRyan Olson <rolson@nvidia.com>
parent 236cb17d
...@@ -42,3 +42,4 @@ testing-nixl-gds = [] ...@@ -42,3 +42,4 @@ testing-nixl-gds = []
[dev-dependencies] [dev-dependencies]
dynamo-memory = { workspace = true, features = ["unsafe-slices"] } dynamo-memory = { workspace = true, features = ["unsafe-slices"] }
rstest = "0.26" rstest = "0.26"
tokio = { workspace = true, features = ["test-util"] }
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
//! NIXL notification-based completion handler. //! NIXL notification-based completion handler.
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
...@@ -21,10 +21,48 @@ pub struct RegisterNixlNotification { ...@@ -21,10 +21,48 @@ pub struct RegisterNixlNotification {
pub event_handle: EventHandle, pub event_handle: EventHandle,
} }
/// Tracking struct for outstanding NIXL notification transfers. // ── Abstractions for testability ────────────────────────────────────
struct OutstandingTransfer {
#[allow(dead_code)] // Kept for potential future cleanup or debugging /// Trait abstracting notification polling (wraps NixlAgent + NotificationMap).
xfer_req: XferRequest, trait NotificationSource: Send {
fn poll_notifications(&self) -> anyhow::Result<HashMap<String, Vec<String>>>;
}
/// Production implementation backed by a real NIXL agent.
struct NixlNotificationSource {
agent: NixlAgent,
}
impl NotificationSource for NixlNotificationSource {
fn poll_notifications(&self) -> anyhow::Result<HashMap<String, Vec<String>>> {
let mut notif_map = NotificationMap::new()?;
self.agent.get_notifications(&mut notif_map, None)?;
Ok(notif_map.take_notifs()?)
}
}
/// Trait for registration messages, abstracting away `XferRequest` for testability.
///
/// `Payload` is kept alive in the outstanding map for the duration of the transfer
/// (e.g. `XferRequest` must not be dropped while the NIXL transfer is in flight).
trait Registration: Send {
type Payload: Send;
fn decompose(self) -> (Uuid, EventHandle, Self::Payload);
}
impl Registration for RegisterNixlNotification {
type Payload = XferRequest;
fn decompose(self) -> (Uuid, EventHandle, XferRequest) {
(self.uuid, self.event_handle, self.xfer_req)
}
}
// ── Internal types ──────────────────────────────────────────────────
/// Tracking struct for outstanding transfers.
struct OutstandingTransfer<T> {
#[allow(dead_code)] // Kept alive for the duration of the transfer
_payload: T,
event_handle: EventHandle, event_handle: EventHandle,
arrived_at: Instant, arrived_at: Instant,
last_warned_at: Option<Instant>, last_warned_at: Option<Instant>,
...@@ -55,14 +93,70 @@ fn check_and_warn_slow_transfer( ...@@ -55,14 +93,70 @@ fn check_and_warn_slow_transfer(
last_warned_at last_warned_at
} }
/// NIXL notification-based transfer completion handler. // ── Shared helpers ──────────────────────────────────────────────────
/// Fetches notifications in batches and matches them against outstanding transfers.
pub async fn process_nixl_notification_events( /// Parse notification UUIDs and collect those matching outstanding transfers.
agent: NixlAgent, /// Unknown UUIDs are optionally routed to `early_arrivals` with a warning.
mut rx: mpsc::Receiver<RegisterNixlNotification>, fn collect_completed<T>(
notifications: HashMap<String, Vec<String>>,
outstanding: &HashMap<Uuid, OutstandingTransfer<T>>,
mut early_arrivals: Option<&mut HashSet<Uuid>>,
) -> Vec<Uuid> {
let mut completed = Vec::new();
for (_agent_name, notif_strings) in notifications {
for notif_str in notif_strings {
if let Ok(notif_uuid) = Uuid::parse_str(&notif_str) {
if outstanding.contains_key(&notif_uuid) {
completed.push(notif_uuid);
} else if let Some(early) = early_arrivals.as_deref_mut() {
warn!(
uuid = %notif_uuid,
"Notification arrived for transfer not in outstanding map (early arrival)"
);
early.insert(notif_uuid);
#[cfg(all(not(test), debug_assertions))]
panic!(
"Notification arrived for transfer not in outstanding map (early arrival); ensure all transfers are registered the NIXL notification can be triggered"
);
}
}
}
}
completed
}
/// Remove completed transfers from the outstanding map and trigger their events.
fn complete_transfers<T>(
completed: Vec<Uuid>,
outstanding: &mut HashMap<Uuid, OutstandingTransfer<T>>,
system: &EventManager,
) {
for uuid in completed {
if let Some(transfer) = outstanding.remove(&uuid)
&& let Err(e) = system.trigger(transfer.event_handle)
{
error!(
uuid = %uuid,
error = %e,
"Failed to trigger completion event"
);
}
}
}
// ── Core processing loop ────────────────────────────────────────────
/// Generic notification event loop, parameterized over the notification source
/// and registration message type for testability.
async fn process_events_core<S: NotificationSource, R: Registration>(
source: S,
mut rx: mpsc::Receiver<R>,
system: Arc<EventManager>, system: Arc<EventManager>,
) { ) {
let mut outstanding: HashMap<Uuid, OutstandingTransfer> = HashMap::new(); let mut outstanding: HashMap<Uuid, OutstandingTransfer<R::Payload>> = HashMap::new();
let mut early_arrivals: HashSet<Uuid> = HashSet::new();
let mut last_early_arrival_warn: Option<Instant> = None;
let mut check_interval = interval(Duration::from_millis(1)); let mut check_interval = interval(Duration::from_millis(1));
loop { loop {
...@@ -70,14 +164,28 @@ pub async fn process_nixl_notification_events( ...@@ -70,14 +164,28 @@ pub async fn process_nixl_notification_events(
// Handle new transfer requests // Handle new transfer requests
notification = rx.recv() => { notification = rx.recv() => {
match notification { match notification {
Some(notif) => { Some(reg) => {
outstanding.insert(notif.uuid, OutstandingTransfer { let (uuid, event_handle, payload) = reg.decompose();
xfer_req: notif.xfer_req, if early_arrivals.remove(&uuid) {
event_handle: notif.event_handle, // Notification arrived before registration — complete immediately.
// Payload is dropped here; the transfer is already done.
drop(payload);
if let Err(e) = system.trigger(event_handle) {
error!(
uuid = %uuid,
error = %e,
"Failed to trigger completion event for early arrival"
);
}
} else {
outstanding.insert(uuid, OutstandingTransfer {
_payload: payload,
event_handle,
arrived_at: Instant::now(), arrived_at: Instant::now(),
last_warned_at: None, last_warned_at: None,
}); });
} }
}
None => { None => {
// Channel closed, finish processing outstanding transfers then exit // Channel closed, finish processing outstanding transfers then exit
break; break;
...@@ -87,52 +195,19 @@ pub async fn process_nixl_notification_events( ...@@ -87,52 +195,19 @@ pub async fn process_nixl_notification_events(
// Periodically fetch and process notifications // Periodically fetch and process notifications
_ = check_interval.tick(), if !outstanding.is_empty() => { _ = check_interval.tick(), if !outstanding.is_empty() => {
// Create notification map inside this branch to avoid Send issues let notifications = match source.poll_notifications() {
let mut notif_map = match NotificationMap::new() {
Ok(map) => map,
Err(e) => {
warn!(error = %e, "Failed to create notification map");
continue;
}
};
// Fetch all pending notifications
if let Err(e) = agent.get_notifications(&mut notif_map, None) {
warn!(error = %e, "Failed to fetch NIXL notifications");
continue;
}
// Process notifications and match against outstanding transfers
let notifications = match notif_map.take_notifs() {
Ok(notifs) => notifs, Ok(notifs) => notifs,
Err(e) => { Err(e) => {
warn!(error = %e, "Failed to extract notifications from map"); warn!(error = %e, "Failed to fetch notifications");
continue; continue;
} }
}; };
let mut completed = Vec::new(); let completed = collect_completed(
notifications,
// Iterate through all notifications &outstanding,
for (_agent_name, notif_strings) in notifications { Some(&mut early_arrivals),
for notif_str in notif_strings {
// Try to parse notification as UUID
// NOTE: This assumes notifications contain UUIDs.
// The actual format may be different and may need adjustment.
if let Ok(notif_uuid) = Uuid::parse_str(&notif_str) {
if outstanding.contains_key(&notif_uuid) {
completed.push(notif_uuid);
} else {
// Notification arrived before we started waiting for it
// This is the race condition we need to handle
warn!(
uuid = %notif_uuid,
"Received notification for transfer not in outstanding map (early arrival)"
); );
}
}
}
}
// Check for slow transfers and update warnings // Check for slow transfers and update warnings
for (uuid, transfer) in outstanding.iter_mut() { for (uuid, transfer) in outstanding.iter_mut() {
...@@ -145,17 +220,21 @@ pub async fn process_nixl_notification_events( ...@@ -145,17 +220,21 @@ pub async fn process_nixl_notification_events(
} }
} }
// Remove completed transfers and signal completion // Warn periodically if early_arrivals has unmatched entries
for uuid in completed { if !early_arrivals.is_empty() {
if let Some(transfer) = outstanding.remove(&uuid) let should_warn = last_early_arrival_warn
&& let Err(e) = system.trigger(transfer.event_handle) { .map(|t| t.elapsed() > Duration::from_secs(30))
error!( .unwrap_or(true);
uuid = %uuid, if should_warn {
error = %e, warn!(
"Failed to trigger completion event" count = early_arrivals.len(),
"early_arrivals buffer has unmatched entries"
); );
last_early_arrival_warn = Some(Instant::now());
} }
} }
complete_transfers(completed, &mut outstanding, &system);
} }
} }
} }
...@@ -165,31 +244,498 @@ pub async fn process_nixl_notification_events( ...@@ -165,31 +244,498 @@ pub async fn process_nixl_notification_events(
while !outstanding.is_empty() { while !outstanding.is_empty() {
check_interval.tick().await; check_interval.tick().await;
let mut notif_map = match NotificationMap::new() { match source.poll_notifications() {
Ok(map) => map, Ok(notifications) => {
Err(_) => continue, let completed = collect_completed(notifications, &outstanding, None);
}; complete_transfers(completed, &mut outstanding, &system);
}
Err(e) => {
warn!(error = %e, "Failed to fetch notifications during shutdown drain");
}
}
}
}
if let Ok(()) = agent.get_notifications(&mut notif_map, None) // ── Public API ──────────────────────────────────────────────────────
&& let Ok(notifications) = notif_map.take_notifs()
{
let mut completed = Vec::new();
for (_agent_name, notif_strings) in notifications { /// NIXL notification-based transfer completion handler.
for notif_str in notif_strings { /// Fetches notifications in batches and matches them against outstanding transfers.
if let Ok(notif_uuid) = Uuid::parse_str(&notif_str) pub async fn process_nixl_notification_events(
&& outstanding.contains_key(&notif_uuid) agent: NixlAgent,
{ rx: mpsc::Receiver<RegisterNixlNotification>,
completed.push(notif_uuid); system: Arc<EventManager>,
) {
process_events_core(NixlNotificationSource { agent }, rx, system).await
}
// ── Tests ───────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::sync::Mutex;
use tokio::task::yield_now;
use velo_events::EventStatus;
// ── Mock notification source ────────────────────────────────────
type NotificationQueue = Arc<Mutex<VecDeque<anyhow::Result<HashMap<String, Vec<String>>>>>>;
struct MockNotificationSource {
queue: NotificationQueue,
}
struct MockControl {
queue: NotificationQueue,
} }
fn mock_source() -> (MockNotificationSource, MockControl) {
let queue = Arc::new(Mutex::new(VecDeque::new()));
(
MockNotificationSource {
queue: queue.clone(),
},
MockControl { queue },
)
}
impl NotificationSource for MockNotificationSource {
fn poll_notifications(&self) -> anyhow::Result<HashMap<String, Vec<String>>> {
self.queue
.lock()
.unwrap()
.pop_front()
.unwrap_or(Ok(HashMap::new()))
} }
} }
for uuid in completed { impl MockControl {
if let Some(transfer) = outstanding.remove(&uuid) { fn push_notification(&self, uuid: Uuid) {
let _ = system.trigger(transfer.event_handle); let mut map = HashMap::new();
map.insert("test_agent".to_string(), vec![uuid.to_string()]);
self.queue.lock().unwrap().push_back(Ok(map));
}
fn push_error(&self) {
self.queue
.lock()
.unwrap()
.push_back(Err(anyhow::anyhow!("mock poll error")));
}
}
// ── Test registration type (no XferRequest needed) ──────────────
struct TestRegistration {
uuid: Uuid,
event_handle: EventHandle,
}
impl Registration for TestRegistration {
type Payload = ();
fn decompose(self) -> (Uuid, EventHandle, ()) {
(self.uuid, self.event_handle, ())
}
}
// ── Helpers ─────────────────────────────────────────────────────
/// Advance paused tokio time and yield so spawned tasks can process.
async fn tick() {
tokio::time::advance(Duration::from_millis(2)).await;
yield_now().await;
} }
// ── check_and_warn_slow_transfer unit tests ─────────────────────
#[test]
fn warn_under_threshold_returns_none() {
let result = check_and_warn_slow_transfer(
&Uuid::new_v4(),
Instant::now() - Duration::from_secs(30),
None,
);
assert!(result.is_none());
} }
#[test]
fn warn_over_threshold_first_time() {
let before = Instant::now();
let result = check_and_warn_slow_transfer(
&Uuid::new_v4(),
Instant::now() - Duration::from_secs(61),
None,
);
let after = Instant::now();
let t = result.expect("should have warned");
assert!(t >= before && t <= after);
} }
#[test]
fn warn_throttled_within_30s() {
let last = Instant::now() - Duration::from_secs(10);
let result = check_and_warn_slow_transfer(
&Uuid::new_v4(),
Instant::now() - Duration::from_secs(90),
Some(last),
);
assert_eq!(result, Some(last));
}
#[test]
fn warn_throttle_expired_after_30s() {
let before = Instant::now();
let result = check_and_warn_slow_transfer(
&Uuid::new_v4(),
Instant::now() - Duration::from_secs(120),
Some(Instant::now() - Duration::from_secs(35)),
);
let t = result.expect("should have re-warned");
assert!(t >= before);
}
// ── Integration tests for process_events_core ───────────────────
#[tokio::test(start_paused = true)]
async fn normal_flow_register_then_notify() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid = Uuid::new_v4();
let event = system.new_event()?;
let handle = event.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register transfer
tx.send(TestRegistration {
uuid,
event_handle: handle,
})
.await?;
yield_now().await;
// Deliver notification
control.push_notification(uuid);
tick().await;
assert_eq!(system.poll(handle)?, EventStatus::Ready);
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn early_arrival_notify_before_register() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid_early = Uuid::new_v4();
let event_early = system.new_event()?;
let handle_early = event_early.into_handle();
// Filler transfer to make the tick branch fire
let uuid_filler = Uuid::new_v4();
let event_filler = system.new_event()?;
let handle_filler = event_filler.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register filler so outstanding is non-empty
tx.send(TestRegistration {
uuid: uuid_filler,
event_handle: handle_filler,
})
.await?;
yield_now().await;
// Notification arrives for uuid_early before it's registered (early arrival)
control.push_notification(uuid_early);
tick().await;
// Now register uuid_early — should complete immediately from early_arrivals
tx.send(TestRegistration {
uuid: uuid_early,
event_handle: handle_early,
})
.await?;
yield_now().await;
assert_eq!(system.poll(handle_early)?, EventStatus::Ready);
// Cleanup filler
control.push_notification(uuid_filler);
tick().await;
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn multiple_transfers_complete_independently() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid_a = Uuid::new_v4();
let event_a = system.new_event()?;
let handle_a = event_a.into_handle();
let uuid_b = Uuid::new_v4();
let event_b = system.new_event()?;
let handle_b = event_b.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register both
tx.send(TestRegistration {
uuid: uuid_a,
event_handle: handle_a,
})
.await?;
tx.send(TestRegistration {
uuid: uuid_b,
event_handle: handle_b,
})
.await?;
yield_now().await;
// Complete B only
control.push_notification(uuid_b);
tick().await;
assert_eq!(system.poll(handle_b)?, EventStatus::Ready);
assert_eq!(system.poll(handle_a)?, EventStatus::Pending);
// Complete A
control.push_notification(uuid_a);
tick().await;
assert_eq!(system.poll(handle_a)?, EventStatus::Ready);
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn mixed_ordering_early_and_normal() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid_early = Uuid::new_v4();
let event_early = system.new_event()?;
let handle_early = event_early.into_handle();
let uuid_normal = Uuid::new_v4();
let event_normal = system.new_event()?;
let handle_normal = event_normal.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register normal first (makes outstanding non-empty for tick)
tx.send(TestRegistration {
uuid: uuid_normal,
event_handle: handle_normal,
})
.await?;
yield_now().await;
// Early arrival for uuid_early
control.push_notification(uuid_early);
tick().await;
// Register uuid_early — triggers immediately
tx.send(TestRegistration {
uuid: uuid_early,
event_handle: handle_early,
})
.await?;
yield_now().await;
assert_eq!(system.poll(handle_early)?, EventStatus::Ready);
assert_eq!(system.poll(handle_normal)?, EventStatus::Pending);
// Complete normal
control.push_notification(uuid_normal);
tick().await;
assert_eq!(system.poll(handle_normal)?, EventStatus::Ready);
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn channel_close_drains_outstanding() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid = Uuid::new_v4();
let event = system.new_event()?;
let handle = event.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register
tx.send(TestRegistration {
uuid,
event_handle: handle,
})
.await?;
yield_now().await;
// Close channel — enters drain loop
drop(tx);
yield_now().await;
// Deliver notification during drain
control.push_notification(uuid);
tick().await;
assert_eq!(system.poll(handle)?, EventStatus::Ready);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn duplicate_notification_is_harmless() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid = Uuid::new_v4();
let event = system.new_event()?;
let handle = event.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register
tx.send(TestRegistration {
uuid,
event_handle: handle,
})
.await?;
yield_now().await;
// Deliver same notification twice (two separate polls)
control.push_notification(uuid);
tick().await;
assert_eq!(system.poll(handle)?, EventStatus::Ready);
// Second notification — UUID no longer in outstanding, goes to early_arrivals
// This should not panic or cause issues
control.push_notification(uuid);
// Need another outstanding transfer for the tick to fire
let uuid2 = Uuid::new_v4();
let event2 = system.new_event()?;
let handle2 = event2.into_handle();
tx.send(TestRegistration {
uuid: uuid2,
event_handle: handle2,
})
.await?;
yield_now().await;
tick().await;
// Original event still Ready, no double-trigger issues
assert_eq!(system.poll(handle)?, EventStatus::Ready);
// Cleanup
control.push_notification(uuid2);
tick().await;
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn poll_error_does_not_crash_loop() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid = Uuid::new_v4();
let event = system.new_event()?;
let handle = event.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register
tx.send(TestRegistration {
uuid,
event_handle: handle,
})
.await?;
yield_now().await;
// Inject a poll error
control.push_error();
tick().await;
// Event should still be pending (error was swallowed, not fatal)
assert_eq!(system.poll(handle)?, EventStatus::Pending);
// Now deliver the real notification — handler recovered
control.push_notification(uuid);
tick().await;
assert_eq!(system.poll(handle)?, EventStatus::Ready);
drop(tx);
task.await?;
Ok(())
}
#[tokio::test(start_paused = true)]
async fn unknown_notification_in_drain_is_ignored() -> anyhow::Result<()> {
let (source, control) = mock_source();
let (tx, rx) = mpsc::channel(16);
let system = Arc::new(EventManager::local());
let uuid = Uuid::new_v4();
let event = system.new_event()?;
let handle = event.into_handle();
let task = tokio::spawn(process_events_core(source, rx, system.clone()));
// Register
tx.send(TestRegistration {
uuid,
event_handle: handle,
})
.await?;
yield_now().await;
// Close channel to enter drain
drop(tx);
yield_now().await;
// Deliver both our UUID and an unknown one — should not panic
let mut map = HashMap::new();
map.insert(
"test_agent".to_string(),
vec![uuid.to_string(), Uuid::new_v4().to_string()],
);
control.queue.lock().unwrap().push_back(Ok(map));
tick().await;
assert_eq!(system.poll(handle)?, EventStatus::Ready);
task.await?;
Ok(())
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment