address.rs 14.1 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Address types for peer discovery.
//!
//! This module provides types for representing worker addresses and peer information:
//! - [`WorkerAddress`]: Opaque byte representation of a peer's network address
//! - [`PeerInfo`]: Combined instance ID and worker address for a discovered peer
//!
//! These types are intentionally transport-agnostic, storing addresses as opaque bytes.
//! The interpretation of these bytes is left to the active message runtime.

use crate::identity::{InstanceId, WorkerId};
use crate::transport::TransportKey;

use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use xxhash_rust::xxh3::xxh3_64;

/// Errors that can occur when working with WorkerAddress.
#[derive(Debug, thiserror::Error)]
pub enum WorkerAddressError {
    /// Attempted to add a key that already exists
    #[error("Key already exists: {0}")]
    KeyExists(String),

    /// Attempted to access or remove a key that doesn't exist
    #[error("Key not found: {0}")]
    KeyNotFound(String),

    /// Failed to encode the map to bytes
    #[error("Encoding error: {0}")]
    EncodingError(#[from] rmp_serde::encode::Error),

    /// Failed to decode bytes to map
    #[error("Decoding error: {0}")]
    DecodingError(#[from] rmp_serde::decode::Error),

    /// Encountered an unsupported format version
    #[error("Unsupported format version: {0}")]
    UnsupportedVersion(u8),

    /// The data format is invalid
    #[error("Invalid format: {0}")]
    InvalidFormat(String),
}

/// Transport-agnostic network address of a worker, as exchanged during discovery.
///
/// Discovery treats the contained bytes as opaque; their meaning belongs to the
/// active message runtime. The accessor helpers decode them as a MessagePack map
/// of transport key to endpoint bytes.
///
/// # Checksum
///
/// [`WorkerAddress::checksum`] hashes the raw bytes with xxh3_64, giving a cheap
/// way to tell whether an address changed when a peer re-registers.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct WorkerAddress(Bytes);

// Custom Serialize/Deserialize to handle Bytes
impl Serialize for WorkerAddress {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serde_bytes::serialize(self.0.as_ref(), serializer)
    }
}

impl<'de> Deserialize<'de> for WorkerAddress {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let bytes: Vec<u8> = serde_bytes::deserialize(deserializer)?;
        Ok(WorkerAddress(Bytes::from(bytes)))
    }
}

impl WorkerAddress {
    /// Create a WorkerAddress from pre-encoded bytes.
    ///
    /// This is used by transport implementations to construct addresses from
    /// MessagePack-encoded map data. The bytes are assumed to be valid MessagePack.
    pub fn from_encoded(bytes: impl Into<Bytes>) -> Self {
        Self(bytes.into())
    }

    /// Get the underlying bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Get the bytes as a Bytes object.
    pub fn to_bytes(&self) -> Bytes {
        self.0.clone()
    }

    /// Compute a checksum of this address for validation.
    ///
    /// This is used to quickly check if an address has changed during re-registration.
    pub fn checksum(&self) -> u64 {
        xxh3_64(self.as_bytes())
    }

    /// Get the list of available transport keys in this address.
    ///
    /// Returns the keys from the internal map as `TransportKey` for type-safe efficient
    /// storage and sharing. This allows callers to see what transport types or endpoints
    /// are available without exposing the full map.
    ///
    /// # Errors
    ///
    /// Returns an error if the internal bytes cannot be decoded as a valid MessagePack map.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use velo_common::{WorkerAddress, TransportKey};
    /// # let address: WorkerAddress = unimplemented!();
    /// let transports = address.available_transports().unwrap();
    /// if transports.contains(&TransportKey::from("tcp")) {
    ///     // TCP transport is available
    /// }
    /// ```
    pub fn available_transports(&self) -> Result<Vec<TransportKey>, WorkerAddressError> {
        let map = decode_to_map(self.as_bytes())?;
        Ok(map.keys().cloned().map(TransportKey::from).collect())
    }

    /// Get a single entry from the internal map.
    ///
    /// This decodes the address and extracts the entry for the given key.
    ///
    /// Accepts any type that can be converted to a string reference, including
    /// `&str`, `String`, `&String`, and `TransportKey`.
    ///
    /// # Errors
    ///
    /// Returns an error if the internal bytes cannot be decoded as a valid MessagePack map.
    pub fn get_entry(&self, key: impl AsRef<str>) -> Result<Option<Bytes>, WorkerAddressError> {
        let map = decode_to_map(self.as_bytes())?;
        Ok(map.get(key.as_ref()).cloned())
    }
}

impl fmt::Debug for WorkerAddress {
    /// Summarize the address as its byte length plus xxh3_64 hash rather than
    /// dumping the raw encoded bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let byte_len = self.0.len();
        let digest = self.checksum();
        let mut tuple = f.debug_tuple("WorkerAddress");
        tuple.field(&format_args!("len={}, xxh3_64=0x{:016x}", byte_len, digest));
        tuple.finish()
    }
}

impl fmt::Display for WorkerAddress {
    /// Render the address by its zero-padded 16-hex-digit xxh3_64 checksum.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let digest = self.checksum();
        f.write_fmt(format_args!("WorkerAddress(xxh3_64=0x{:016x})", digest))
    }
}

// ============================================================================
// Internal Decoding Helper
// ============================================================================

/// Decode WorkerAddress bytes from MessagePack into a map.
fn decode_to_map(bytes: &[u8]) -> Result<HashMap<Arc<str>, Bytes>, WorkerAddressError> {
    if bytes.is_empty() {
        return Err(WorkerAddressError::InvalidFormat("Empty bytes".to_string()));
    }

    // Decode MessagePack
    let decoded: HashMap<String, Vec<u8>> = rmp_serde::from_slice(bytes)?;

    // Convert to HashMap<Arc<str>, Bytes>
    Ok(decoded
        .into_iter()
        .map(|(k, v)| (Arc::from(k.as_str()), Bytes::from(v)))
        .collect())
}

/// A discovered peer: its instance ID paired with the address used to reach it.
///
/// Discovery lookups hand back this type; it carries everything required to both
/// identify and connect to the peer.
///
/// # Example
///
/// ```no_run
/// use velo_common::{InstanceId, PeerInfo};
/// # // Addresses are normally produced by a transport; elided here.
/// # use velo_common::WorkerAddress;
/// # let address: WorkerAddress = unimplemented!();
///
/// let instance_id = InstanceId::new_v4();
/// let peer_info = PeerInfo::new(instance_id, address);
///
/// assert_eq!(peer_info.instance_id(), instance_id);
/// assert_eq!(peer_info.worker_id(), instance_id.worker_id());
/// ```
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PeerInfo {
    /// Identity of the peer instance.
    pub instance_id: InstanceId,
    /// Address to use when connecting to the peer.
    pub worker_address: WorkerAddress,
}

impl PeerInfo {
    /// Create a new PeerInfo.
    pub fn new(instance_id: InstanceId, worker_address: WorkerAddress) -> Self {
        Self {
            instance_id,
            worker_address,
        }
    }

    /// Get the instance ID.
    pub fn instance_id(&self) -> InstanceId {
        self.instance_id
    }

    /// Get the worker ID (derived from instance ID).
    pub fn worker_id(&self) -> WorkerId {
        self.instance_id.worker_id()
    }

    /// Get a reference to the worker address.
    pub fn worker_address(&self) -> &WorkerAddress {
        &self.worker_address
    }

    /// Get the worker address checksum for validation.
    pub fn address_checksum(&self) -> u64 {
        self.worker_address.checksum()
    }

    /// Consume self and return the worker address.
    pub fn into_address(self) -> WorkerAddress {
        self.worker_address
    }

    /// Decompose into instance ID and worker address.
    pub fn into_parts(self) -> (InstanceId, WorkerAddress) {
        (self.instance_id, self.worker_address)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Helper to create a test address with MessagePack encoding
    fn make_test_address(entries: &[(&str, &[u8])]) -> WorkerAddress {
        let map: HashMap<String, Vec<u8>> = entries
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_vec()))
            .collect();
        let encoded = rmp_serde::to_vec(&map).unwrap();
        WorkerAddress::from_encoded(encoded)
    }

    #[test]
    fn test_worker_address_from_encoded() {
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);

        // Verify we can get the entry back
        let entry = address.get_entry("endpoint").unwrap();
        assert_eq!(entry, Some(Bytes::from_static(b"tcp://127.0.0.1:5555")));
    }

    #[test]
    fn test_worker_address_checksum() {
        let address1 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let address2 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let address3 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:6666")]);

        // Same content = same checksum
        assert_eq!(address1.checksum(), address2.checksum());

        // Different content = different checksum
        assert_ne!(address1.checksum(), address3.checksum());
    }

    #[test]
    fn test_worker_address_equality() {
        let address1 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let address2 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let address3 = make_test_address(&[("endpoint", b"tcp://127.0.0.1:6666")]);

        assert_eq!(address1, address2);
        assert_ne!(address1, address3);
    }

    #[test]
    fn test_worker_address_debug() {
        let address = make_test_address(&[("test", b"value")]);
        let debug_str = format!("{:?}", address);

        assert!(debug_str.contains("WorkerAddress"));
        assert!(debug_str.contains("len="));
        assert!(debug_str.contains("xxh3_64="));
    }

    #[test]
    fn test_worker_address_display() {
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let expected = format!("WorkerAddress(xxh3_64=0x{:016x})", address.checksum());
        assert_eq!(address.to_string(), expected);
    }

    #[test]
    fn test_worker_address_msgpack_roundtrip() {
        // Exercises the custom Serialize/Deserialize impls through a binary format.
        let address = make_test_address(&[
            ("tcp", b"tcp://127.0.0.1:5555"),
            ("rdma", b"rdma://10.0.0.1:6666"),
        ]);

        let encoded = rmp_serde::to_vec(&address).unwrap();
        let decoded: WorkerAddress = rmp_serde::from_slice(&encoded).unwrap();

        assert_eq!(decoded, address);
        assert_eq!(decoded.checksum(), address.checksum());
    }

    #[test]
    fn test_empty_bytes_is_invalid_format() {
        let address = WorkerAddress::from_encoded(Vec::<u8>::new());

        assert!(matches!(
            address.available_transports(),
            Err(WorkerAddressError::InvalidFormat(_))
        ));
        assert!(matches!(
            address.get_entry("tcp"),
            Err(WorkerAddressError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_non_map_payload_is_decoding_error() {
        // 0x2a is a MessagePack positive fixint (42): valid MessagePack, but not a map.
        let address = WorkerAddress::from_encoded(vec![0x2a_u8]);

        assert!(matches!(
            address.available_transports(),
            Err(WorkerAddressError::DecodingError(_))
        ));
        assert!(matches!(
            address.get_entry("tcp"),
            Err(WorkerAddressError::DecodingError(_))
        ));
    }

    #[test]
    fn test_available_transports() {
        let address = make_test_address(&[
            ("tcp", b"tcp://127.0.0.1:5555"),
            ("rdma", b"rdma://10.0.0.1:6666"),
            ("udp", b"udp://127.0.0.1:7777"),
        ]);

        let transports = address.available_transports().unwrap();
        assert_eq!(transports.len(), 3);
        assert!(transports.contains(&TransportKey::from("tcp")));
        assert!(transports.contains(&TransportKey::from("rdma")));
        assert!(transports.contains(&TransportKey::from("udp")));
    }

    #[test]
    fn test_available_transports_empty() {
        let address = make_test_address(&[]);
        let transports = address.available_transports().unwrap();
        assert_eq!(transports.len(), 0);
    }

    #[test]
    fn test_get_entry() {
        let address =
            make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555"), ("protocol", b"tcp")]);

        // Get existing entry
        assert_eq!(
            address.get_entry("endpoint").unwrap().unwrap(),
            Bytes::from_static(b"tcp://127.0.0.1:5555")
        );

        // Get nonexistent entry
        assert!(address.get_entry("nonexistent").unwrap().is_none());
    }

    #[test]
    fn test_get_entry_with_transport_key() {
        let address = make_test_address(&[
            ("tcp", b"tcp://127.0.0.1:5555"),
            ("rdma", b"rdma://10.0.0.1:6666"),
        ]);

        // Test get_entry with TransportKey
        let tcp_key = TransportKey::from("tcp");
        let result = address.get_entry(tcp_key).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"tcp://127.0.0.1:5555")));

        // Test get_entry with String
        let result = address.get_entry(String::from("rdma")).unwrap();
        assert_eq!(result, Some(Bytes::from_static(b"rdma://10.0.0.1:6666")));
    }

    #[test]
    fn test_peer_info_creation() {
        let instance_id = InstanceId::new_v4();
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);

        let peer_info = PeerInfo::new(instance_id, address.clone());

        assert_eq!(peer_info.instance_id(), instance_id);
        assert_eq!(peer_info.worker_id(), instance_id.worker_id());
        assert_eq!(peer_info.worker_address(), &address);
    }

    #[test]
    fn test_peer_info_checksum() {
        let instance_id = InstanceId::new_v4();
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);

        let peer_info = PeerInfo::new(instance_id, address.clone());

        assert_eq!(peer_info.address_checksum(), address.checksum());
    }

    #[test]
    fn test_peer_info_into_address() {
        let instance_id = InstanceId::new_v4();
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);

        let peer_info = PeerInfo::new(instance_id, address.clone());
        let extracted_address = peer_info.into_address();

        assert_eq!(extracted_address, address);
    }

    #[test]
    fn test_peer_info_into_parts() {
        let instance_id = InstanceId::new_v4();
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);

        let peer_info = PeerInfo::new(instance_id, address.clone());
        let (extracted_id, extracted_address) = peer_info.into_parts();

        assert_eq!(extracted_id, instance_id);
        assert_eq!(extracted_address, address);
    }

    #[test]
    fn test_peer_info_serde() {
        let instance_id = InstanceId::new_v4();
        let address = make_test_address(&[("endpoint", b"tcp://127.0.0.1:5555")]);
        let peer_info = PeerInfo::new(instance_id, address);

        // Serialize to JSON
        let json = serde_json::to_string(&peer_info).unwrap();

        // Deserialize back
        let deserialized: PeerInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.instance_id(), instance_id);
        assert_eq!(deserialized.worker_id(), instance_id.worker_id());

        // Verify the entry is preserved
        let entry = deserialized.worker_address().get_entry("endpoint").unwrap();
        assert_eq!(entry, Some(Bytes::from_static(b"tcp://127.0.0.1:5555")));
    }
}