Unverified Commit 1208f017 authored by akshaver's avatar akshaver Committed by GitHub
Browse files

chore: add loopback as default address for network (#3250)

parent 65071f21
......@@ -5,7 +5,7 @@ use core::panic;
use socket2::{Domain, SockAddr, Socket, Type};
use std::{
collections::HashMap,
net::{SocketAddr, TcpListener},
net::{IpAddr, SocketAddr, TcpListener},
os::fd::{AsFd, FromRawFd},
sync::Arc,
};
......@@ -15,6 +15,26 @@ use bytes::Bytes;
use derive_builder::Builder;
use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
// Trait for IP address resolution - allows dependency injection for testing
pub trait IpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
// Default implementation using the real local_ip_address crate
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
use serde::{Deserialize, Serialize};
use tokio::{
io::AsyncWriteExt,
......@@ -111,6 +131,13 @@ impl TcpStreamServer {
}
pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
Self::new_with_resolver(options, DefaultIpResolver).await
}
pub async fn new_with_resolver<R: IpResolver>(
options: ServerOptions,
resolver: R,
) -> Result<Arc<Self>, PipelineError> {
let local_ip = match options.interface {
Some(interface) => {
let interfaces: HashMap<String, std::net::IpAddr> =
......@@ -124,16 +151,19 @@ impl TcpStreamServer {
)))?
.to_string()
}
None => local_ip()
.or_else(|err| match err {
Error::LocalIpAddressNotFound => {
// Fall back to IPv6 if no IPv4 addresses are found
local_ipv6()
}
None => {
let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => resolver.local_ipv6(),
_ => Err(err),
})
.unwrap()
.to_string(),
});
match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
Err(err) => return Err(err.into()),
}
.to_string()
}
};
let state = Arc::new(Mutex::new(State::default()));
......@@ -611,3 +641,128 @@ fn process_control_message(message: Bytes) -> Result<ControlAction> {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::AsyncEngineContextProvider;
use crate::pipeline::Context;
// Mock resolver that always fails to simulate the fallback scenario
struct FailingIpResolver;
impl IpResolver for FailingIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
}
#[tokio::test]
async fn test_tcp_stream_server_default_behavior() {
// Test that TcpStreamServer::new works with default options
// This verifies normal operation when IP detection succeeds
let options = ServerOptions::default();
let result = TcpStreamServer::new(options).await;
assert!(
result.is_ok(),
"TcpStreamServer::new should succeed with default options"
);
let server = result.unwrap();
// Verify the server can be used by registering a stream
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
// Verify connection info is available and valid
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
// Should have a valid port assigned
assert!(
socket_addr.port() > 0,
"Server should be assigned a valid port number"
);
println!(
"Server created successfully with address: {}",
tcp_info.address
);
}
#[tokio::test]
async fn test_tcp_stream_server_fallback_to_loopback() {
// Test fallback behavior using a mock resolver that always fails
// This guarantees the fallback logic is triggered
let options = ServerOptions::builder().port(0).build().unwrap();
// Use the failing resolver to force the fallback
let result = TcpStreamServer::new_with_resolver(options, FailingIpResolver).await;
assert!(
result.is_ok(),
"Server creation should succeed with fallback even when IP detection fails"
);
let server = result.unwrap();
// Get the actual bound address by registering a stream
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
// With the failing resolver, fallback should ALWAYS be used
let ip = socket_addr.ip();
assert!(
ip.is_loopback(),
"Should use loopback when IP detection fails"
);
// Verify it's specifically 127.0.0.1 (the fallback value from the patch)
assert_eq!(
ip,
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
"Fallback should use exactly 127.0.0.1, got: {}",
ip
);
println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);
// The server should work with the fallback IP
assert!(socket_addr.port() > 0, "Server should have a valid port");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment