Unverified Commit 1208f017 authored by akshaver's avatar akshaver Committed by GitHub
Browse files

chore: add loopback as default address for network (#3250)

parent 65071f21
...@@ -5,7 +5,7 @@ use core::panic; ...@@ -5,7 +5,7 @@ use core::panic;
use socket2::{Domain, SockAddr, Socket, Type}; use socket2::{Domain, SockAddr, Socket, Type};
use std::{ use std::{
collections::HashMap, collections::HashMap,
net::{SocketAddr, TcpListener}, net::{IpAddr, SocketAddr, TcpListener},
os::fd::{AsFd, FromRawFd}, os::fd::{AsFd, FromRawFd},
sync::Arc, sync::Arc,
}; };
...@@ -15,6 +15,26 @@ use bytes::Bytes; ...@@ -15,6 +15,26 @@ use bytes::Bytes;
use derive_builder::Builder; use derive_builder::Builder;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6}; use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
// Trait for IP address resolution - allows dependency injection for testing
pub trait IpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{ use tokio::{
io::AsyncWriteExt, io::AsyncWriteExt,
...@@ -111,6 +131,13 @@ impl TcpStreamServer { ...@@ -111,6 +131,13 @@ impl TcpStreamServer {
} }
pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> { pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
Self::new_with_resolver(options, DefaultIpResolver).await
}
pub async fn new_with_resolver<R: IpResolver>(
options: ServerOptions,
resolver: R,
) -> Result<Arc<Self>, PipelineError> {
let local_ip = match options.interface { let local_ip = match options.interface {
Some(interface) => { Some(interface) => {
let interfaces: HashMap<String, std::net::IpAddr> = let interfaces: HashMap<String, std::net::IpAddr> =
...@@ -124,16 +151,19 @@ impl TcpStreamServer { ...@@ -124,16 +151,19 @@ impl TcpStreamServer {
)))? )))?
.to_string() .to_string()
} }
None => local_ip() None => {
.or_else(|err| match err { let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => { Error::LocalIpAddressNotFound => resolver.local_ipv6(),
// Fall back to IPv6 if no IPv4 addresses are found
local_ipv6()
}
_ => Err(err), _ => Err(err),
}) });
.unwrap()
.to_string(), match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
Err(err) => return Err(err.into()),
}
.to_string()
}
}; };
let state = Arc::new(Mutex::new(State::default())); let state = Arc::new(Mutex::new(State::default()));
...@@ -611,3 +641,128 @@ fn process_control_message(message: Bytes) -> Result<ControlAction> { ...@@ -611,3 +641,128 @@ fn process_control_message(message: Bytes) -> Result<ControlAction> {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::AsyncEngineContextProvider;
use crate::pipeline::Context;
// Mock resolver that always fails to simulate the fallback scenario
struct FailingIpResolver;
impl IpResolver for FailingIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
}
#[tokio::test]
async fn test_tcp_stream_server_default_behavior() {
// Test that TcpStreamServer::new works with default options
// This verifies normal operation when IP detection succeeds
let options = ServerOptions::default();
let result = TcpStreamServer::new(options).await;
assert!(
result.is_ok(),
"TcpStreamServer::new should succeed with default options"
);
let server = result.unwrap();
// Verify the server can be used by registering a stream
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
// Verify connection info is available and valid
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
// Should have a valid port assigned
assert!(
socket_addr.port() > 0,
"Server should be assigned a valid port number"
);
println!(
"Server created successfully with address: {}",
tcp_info.address
);
}
#[tokio::test]
async fn test_tcp_stream_server_fallback_to_loopback() {
// Test fallback behavior using a mock resolver that always fails
// This guarantees the fallback logic is triggered
let options = ServerOptions::builder().port(0).build().unwrap();
// Use the failing resolver to force the fallback
let result = TcpStreamServer::new_with_resolver(options, FailingIpResolver).await;
assert!(
result.is_ok(),
"Server creation should succeed with fallback even when IP detection fails"
);
let server = result.unwrap();
// Get the actual bound address by registering a stream
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
// With the failing resolver, fallback should ALWAYS be used
let ip = socket_addr.ip();
assert!(
ip.is_loopback(),
"Should use loopback when IP detection fails"
);
// Verify it's specifically 127.0.0.1 (the fallback value from the patch)
assert_eq!(
ip,
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
"Fallback should use exactly 127.0.0.1, got: {}",
ip
);
println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);
// The server should work with the fallback IP
assert!(socket_addr.port() > 0, "Server should have a valid port");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment