Commit dddebc0d authored by Ryan Olson, committed by GitHub
Browse files

fix: tcp retry and error handling updates (#169)


Signed-off-by: Ryan Olson <ryanolson@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent c0e008b4
...@@ -66,7 +66,14 @@ impl PushEndpoint { ...@@ -66,7 +66,14 @@ impl PushEndpoint {
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!(worker_id, "handling new request"); tracing::trace!(worker_id, "handling new request");
let result = ingress.handle_payload(req.message.payload).await; let result = ingress.handle_payload(req.message.payload).await;
tracing::trace!(worker_id, "request handled: {:?}", result); match result {
Ok(_) => {
tracing::trace!(worker_id, "request handled successfully");
}
Err(e) => {
tracing::warn!("Failed to handle request: {:?}", e);
}
}
}); });
} else { } else {
break; break;
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
// limitations under the License. // limitations under the License.
use super::*; use super::*;
use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[async_trait] #[async_trait]
......
...@@ -66,18 +66,18 @@ impl From<TcpStreamConnectionInfo> for ConnectionInfo { ...@@ -66,18 +66,18 @@ impl From<TcpStreamConnectionInfo> for ConnectionInfo {
} }
impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo { impl TryFrom<ConnectionInfo> for TcpStreamConnectionInfo {
type Error = String; type Error = anyhow::Error;
fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> { fn try_from(info: ConnectionInfo) -> Result<Self, Self::Error> {
if info.transport != TCP_TRANSPORT { if info.transport != TCP_TRANSPORT {
return Err(format!( return Err(anyhow::anyhow!(
"Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed", "Invalid transport; TcpClient requires the transport to be `tcp_server`; however {} was passed",
info.transport info.transport
)); ));
} }
serde_json::from_str(&info.info) serde_json::from_str(&info.info)
.map_err(|e| format!("Failed parse ConnectionInfo: {:?}", e)) .map_err(|e| anyhow::anyhow!("Failed parse ConnectionInfo: {:?}", e))
} }
} }
......
...@@ -26,7 +26,8 @@ use crate::pipeline::network::{ ...@@ -26,7 +26,8 @@ use crate::pipeline::network::{
codec::{TwoPartCodec, TwoPartMessage}, codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType, tcp::StreamType,
ConnectionInfo, ResponseStreamPrologue, StreamSender, ConnectionInfo, ResponseStreamPrologue, StreamSender,
}; // Import SinkExt to use the `send` method };
use crate::{error, ErrorContext, Result}; // Import SinkExt to use the `send` method
#[allow(dead_code)] #[allow(dead_code)]
pub struct TcpClient { pub struct TcpClient {
...@@ -46,34 +47,49 @@ impl TcpClient { ...@@ -46,34 +47,49 @@ impl TcpClient {
TcpClient { worker_id } TcpClient { worker_id }
} }
async fn connect(address: &str) -> Result<TcpStream, String> { async fn connect(address: &str) -> std::io::Result<TcpStream> {
let socket = TcpStream::connect(address) // try to connect to the address; retry with exponential backoff if AddrNotAvailable
.await let backoff = std::time::Duration::from_millis(200);
.map_err(|e| format!("failed to connect: {:?}", e))?; loop {
match TcpStream::connect(address).await {
Ok(socket) => {
socket.set_nodelay(true)?;
return Ok(socket);
}
Err(e) => {
if e.kind() == std::io::ErrorKind::AddrNotAvailable {
tracing::warn!("retry warning: failed to connect: {:?}", e);
socket // TODO(#173) - remove with resolution of issue
.set_nodelay(true) #[cfg(debug_assertions)]
.map_err(|e| format!("failed to set nodelay: {:?}", e))?; eprintln!("retry warning: failed to connect: {:?}", e);
Ok(socket) tokio::time::sleep(backoff).await;
} else {
return Err(e);
}
}
}
}
} }
pub async fn create_response_steam( pub async fn create_response_steam(
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
info: ConnectionInfo, info: ConnectionInfo,
) -> Result<StreamSender, String> { ) -> Result<StreamSender> {
let info = TcpStreamConnectionInfo::try_from(info)?; let info =
TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
tracing::trace!("Creating response stream for {:?}", info); tracing::trace!("Creating response stream for {:?}", info);
if info.stream_type != StreamType::Response { if info.stream_type != StreamType::Response {
return Err(format!( return Err(error!(
"Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed", "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
info.stream_type info.stream_type
)); ));
} }
if info.context != context.id() { if info.context != context.id() {
return Err(format!( return Err(error!(
"Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed", "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
context.id(), context.id(),
info.context info.context
...@@ -93,7 +109,7 @@ impl TcpClient { ...@@ -93,7 +109,7 @@ impl TcpClient {
// captured by the monitor task // captured by the monitor task
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>(); let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
tokio::spawn(control_handler(framed_reader, context, alive_tx)); let reader_task = tokio::spawn(handle_reader(framed_reader, context, alive_tx));
// transport specific handshake message // transport specific handshake message
let handshake = CallHomeHandshake { let handshake = CallHomeHandshake {
...@@ -104,7 +120,7 @@ impl TcpClient { ...@@ -104,7 +120,7 @@ impl TcpClient {
let handshake_bytes = match serde_json::to_vec(&handshake) { let handshake_bytes = match serde_json::to_vec(&handshake) {
Ok(hb) => hb, Ok(hb) => hb,
Err(err) => { Err(err) => {
return Err(format!( return Err(error!(
"create_response_steam: Error converting CallHomeHandshake to JSON array: {err:#}" "create_response_steam: Error converting CallHomeHandshake to JSON array: {err:#}"
)); ));
} }
...@@ -115,13 +131,35 @@ impl TcpClient { ...@@ -115,13 +131,35 @@ impl TcpClient {
framed_writer framed_writer
.send(msg) .send(msg)
.await .await
.map_err(|e| format!("failed to send handshake: {:?}", e))?; .map_err(|e| error!("failed to send handshake: {:?}", e))?;
// set up the channel to send bytes to the transport layer // set up the channel to send bytes to the transport layer
let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(16); let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(16);
// forwards the bytes sent from this stream to the transport layer; holds the alive_rx half of the oneshot channel // forwards the bytes sent from this stream to the transport layer; holds the alive_rx half of the oneshot channel
tokio::spawn(forward_handler(bytes_rx, framed_writer, alive_rx));
let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx));
tokio::spawn(async move {
// await both tasks
let (reader, writer) = tokio::join!(reader_task, writer_task);
match (reader, writer) {
(Ok(reader), Ok(writer)) => {
let reader = reader.into_inner();
let writer = writer.into_inner();
let mut stream = reader.unsplit(writer);
// close the stream
Ok(stream.shutdown().await?)
}
_ => {
tracing::error!("failed to join reader and writer tasks");
anyhow::bail!("failed to join reader and writer tasks");
}
}
});
// set up the prologue for the stream // set up the prologue for the stream
// this might have transport specific metadata in the future // this might have transport specific metadata in the future
...@@ -137,13 +175,13 @@ impl TcpClient { ...@@ -137,13 +175,13 @@ impl TcpClient {
} }
} }
/// monitors the channel for a cancellation signal async fn handle_reader(
/// this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
async fn control_handler(
mut framed_reader: FramedRead<ReadHalf<TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
mut alive_tx: tokio::sync::oneshot::Sender<()>, alive_tx: tokio::sync::oneshot::Sender<()>,
) { ) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
let mut framed_reader = framed_reader;
let mut alive_tx = alive_tx;
loop { loop {
tokio::select! { tokio::select! {
msg = framed_reader.next() => { msg = framed_reader.next() => {
...@@ -151,15 +189,16 @@ async fn control_handler( ...@@ -151,15 +189,16 @@ async fn control_handler(
Some(Ok(two_part_msg)) => { Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() { match two_part_msg.optional_parts() {
(Some(bytes), None) => { (Some(bytes), None) => {
let msg: ControlMessage = match serde_json::from_slice(bytes) { let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
Ok(msg) => msg, Ok(msg) => msg,
Err(err) => { Err(_) => {
let json_str = String::from_utf8_lossy(bytes); // TODO(#171) - address fatal errors
tracing::error!(%err, %json_str, "control_handler fatal error deserializing JSON to ControlMessage"); tracing::error!("fatal error - invalid control message detected");
context.kill();
break; break;
} }
}; };
match msg { match msg {
ControlMessage::Stop => { ControlMessage::Stop => {
context.stop(); context.stop();
...@@ -172,17 +211,23 @@ async fn control_handler( ...@@ -172,17 +211,23 @@ async fn control_handler(
} }
} }
_ => { _ => {
// we should not receive this // not a control message, so we just continue
continue;
} }
} }
} }
Some(Err(e)) => { Some(Err(_)) => {
panic!("failed to decode message from stream: {:?}", e); // TODO(#171) - address fatal errors
// break; // in this case the binary representation of the message is invalid
tracing::error!("fatal error - failed to decode message from stream");
break;
} }
None => { None => {
// the stream was closed, we should stop the stream // let mut writer = framed_reader.into_inner();
return; // if let Err(e) = writer.shutdown().await {
// tracing::trace!("failed to shutdown reader: {:?}", e);
// }
break;
} }
} }
} }
...@@ -192,24 +237,26 @@ async fn control_handler( ...@@ -192,24 +237,26 @@ async fn control_handler(
} }
} }
} }
// framed_writer.get_mut().shutdown().await.unwrap(); framed_reader
} }
async fn forward_handler( async fn handle_writer(
mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>, mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
mut framed_writer: FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>,
alive_rx: tokio::sync::oneshot::Receiver<()>, alive_rx: tokio::sync::oneshot::Receiver<()>,
) { ) -> FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec> {
while let Some(msg) = bytes_rx.recv().await { while let Some(msg) = bytes_rx.recv().await {
if let Err(e) = framed_writer.send(msg).await { if let Err(e) = framed_writer.send(msg).await {
tracing::trace!(%e, "failed to send message to stream; possible disconnect"); tracing::trace!(
"failed to send message to stream; possible disconnect: {:?}",
e
);
// TODO - possibly propagate the error upstream // TODO - possibly propagate the error upstream
break; break;
} }
} }
drop(alive_rx); drop(alive_rx);
if let Err(e) = framed_writer.get_mut().shutdown().await {
tracing::trace!("failed to shutdown writer: {:?}", e); framed_writer
}
} }
...@@ -13,9 +13,14 @@ ...@@ -13,9 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anyhow::Result;
use core::panic; use core::panic;
use std::{collections::HashMap, sync::Arc}; use socket2::{Domain, SockAddr, Socket, Type};
use std::{
collections::HashMap,
net::{SocketAddr, TcpListener},
os::fd::{AsFd, FromRawFd},
sync::Arc,
};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use bytes::Bytes; use bytes::Bytes;
...@@ -26,6 +31,7 @@ use serde::{Deserialize, Serialize}; ...@@ -26,6 +31,7 @@ use serde::{Deserialize, Serialize};
use tokio::{ use tokio::{
io::AsyncWriteExt, io::AsyncWriteExt,
sync::{mpsc, oneshot}, sync::{mpsc, oneshot},
time,
}; };
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
...@@ -42,6 +48,7 @@ use crate::pipeline::{ ...@@ -42,6 +48,7 @@ use crate::pipeline::{
}, },
PipelineError, PipelineError,
}; };
use crate::{error, ErrorContext, Result};
#[allow(dead_code)] #[allow(dead_code)]
type ResponseType = TwoPartMessage; type ResponseType = TwoPartMessage;
...@@ -107,7 +114,7 @@ struct RequestedRecvConnection { ...@@ -107,7 +114,7 @@ struct RequestedRecvConnection {
struct State { struct State {
tx_subjects: HashMap<String, RequestedSendConnection>, tx_subjects: HashMap<String, RequestedSendConnection>,
rx_subjects: HashMap<String, RequestedRecvConnection>, rx_subjects: HashMap<String, RequestedRecvConnection>,
handle: Option<tokio::task::JoinHandle<()>>, handle: Option<tokio::task::JoinHandle<Result<()>>>,
} }
impl TcpStreamServer { impl TcpStreamServer {
...@@ -140,7 +147,7 @@ impl TcpStreamServer { ...@@ -140,7 +147,7 @@ impl TcpStreamServer {
PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e)) PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
})?; })?;
tracing::info!("TcpStreamServer started on {}:{}", local_ip, local_port); tracing::info!("tcp transport service on {}:{}", local_ip, local_port);
Ok(Arc::new(Self { Ok(Arc::new(Self {
local_ip, local_ip,
...@@ -273,7 +280,7 @@ async fn tcp_listener( ...@@ -273,7 +280,7 @@ async fn tcp_listener(
addr: String, addr: String,
state: Arc<Mutex<State>>, state: Arc<Mutex<State>>,
read_tx: tokio::sync::oneshot::Sender<Result<u16>>, read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
) { ) -> Result<()> {
let listener = tokio::net::TcpListener::bind(&addr) let listener = tokio::net::TcpListener::bind(&addr)
.await .await
.map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e)); .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
...@@ -293,20 +300,57 @@ async fn tcp_listener( ...@@ -293,20 +300,57 @@ async fn tcp_listener(
} }
Err(e) => { Err(e) => {
read_tx.send(Err(e)).expect("Failed to send ready signal"); read_tx.send(Err(e)).expect("Failed to send ready signal");
return; return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
} }
}; };
// TODO(#173) - alternative / not fully functional exploration for #173; removed when resolved.
// let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
// // Set the socket options
// socket.set_reuse_address(true)?;
// socket.set_nodelay(true)?;
// let addr: SocketAddr = addr.parse()?;
// //let addr: SocketAddr = "[::1]:0".parse()?;
// socket.bind(&addr.into())?;
// socket.listen(128)?;
// let listener: TcpListener = socket.into();
// let listener = tokio::net::TcpListener::from_std(listener)?;
// let addr = listener
// .local_addr()
// .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))?;
// read_tx
// .send(Ok(addr.port()))
// .expect("Failed to send ready signal");
loop { loop {
// todo - add instrumentation
// todo - add counter for all accepted connections
// todo - add gauge for all inflight connections
// todo - add counter for incoming bytes
// todo - add counter for outgoing bytes
let (stream, _addr) = match listener.accept().await { let (stream, _addr) = match listener.accept().await {
Ok(x) => x, Ok((stream, _addr)) => (stream, _addr),
Err(err) => { Err(e) => {
// TODO: Probably this is normal, user Ctrl-C something like that, find out // the client should retry, so we don't need to abort
tracing::info!(%err, addr, "TCP listener closed"); tracing::warn!("failed to accept tcp connection: {}", e);
break; eprintln!("failed to accept tcp connection: {}", e);
continue;
} }
}; };
stream.set_nodelay(true).unwrap();
match stream.set_nodelay(true) {
Ok(_) => (),
Err(e) => {
tracing::warn!("failed to set tcp stream to nodelay: {}", e);
}
}
tokio::spawn(handle_connection(stream, state.clone())); tokio::spawn(handle_connection(stream, state.clone()));
} }
...@@ -315,17 +359,18 @@ async fn tcp_listener( ...@@ -315,17 +359,18 @@ async fn tcp_listener(
async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) { async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
let result = process_stream(stream, state).await; let result = process_stream(stream, state).await;
match result { match result {
Ok(_) => tracing::trace!("TcpStream connection closed"), Ok(_) => tracing::trace!("successfully processed tcp connection"),
Err(e) => tracing::error!("TcpStream connection failed: {}", e), Err(e) => {
tracing::warn!("failed to handle tcp connection: {}", e);
#[cfg(debug_assertions)]
eprintln!("failed to handle tcp connection: {}", e);
}
} }
} }
/// This method is responsible for the internal tcp stream handshake /// This method is responsible for the internal tcp stream handshake
/// The handshake will specialize the stream as a request/sender or response/receiver stream /// The handshake will specialize the stream as a request/sender or response/receiver stream
async fn process_stream( async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
stream: tokio::net::TcpStream,
state: Arc<Mutex<State>>,
) -> Result<(), String> {
// split the socket into a reader and a writer // split the socket into a reader and a writer
let (read_half, write_half) = tokio::io::split(stream); let (read_half, write_half) = tokio::io::split(stream);
...@@ -338,19 +383,18 @@ async fn tcp_listener( ...@@ -338,19 +383,18 @@ async fn tcp_listener(
let first_message = framed_reader let first_message = framed_reader
.next() .next()
.await .await
.ok_or("Connection closed without a ControlMessge".to_string())? .ok_or(error!("Connection closed without a ControlMessage"))??;
.map_err(|e| e.to_string())?;
// we await on the raw bytes which should come in as a header only message // we await on the raw bytes which should come in as a header only message
// todo - improve error handling - check for no data // todo - improve error handling - check for no data
let handshake: CallHomeHandshake = match first_message.header() { let handshake: CallHomeHandshake = match first_message.header() {
Some(header) => serde_json::from_slice(header).map_err(|e| { Some(header) => serde_json::from_slice(header).map_err(|e| {
format!( error!(
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}", "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
) )
})?, })?,
None => { None => {
return Err("Expected ControlMessage, got DataMessage".to_string()); return Err(error!("Expected ControlMessage, got DataMessage"));
} }
}; };
...@@ -362,10 +406,9 @@ async fn tcp_listener( ...@@ -362,10 +406,9 @@ async fn tcp_listener(
.await .await
} }
} }
.map_err(|e| format!("Failed to process stream: {e}"))
} }
async fn process_request_stream() -> Result<(), String> { async fn process_request_stream() -> Result<()> {
Ok(()) Ok(())
} }
...@@ -374,12 +417,12 @@ async fn tcp_listener( ...@@ -374,12 +417,12 @@ async fn tcp_listener(
state: Arc<Mutex<State>>, state: Arc<Mutex<State>>,
mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>, mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>, writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
) -> Result<(), String> { ) -> Result<()> {
let response_stream = state let response_stream = state
.lock().await .lock().await
.rx_subjects .rx_subjects
.remove(&subject) .remove(&subject)
.ok_or(format!("Subject not found: {subject}; upstream publisher specified a subject unknown to the downsteam subscriber"))?; .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
// unwrap response_stream // unwrap response_stream
let RequestedRecvConnection { let RequestedRecvConnection {
...@@ -392,14 +435,13 @@ async fn tcp_listener( ...@@ -392,14 +435,13 @@ async fn tcp_listener(
let prologue = reader let prologue = reader
.next() .next()
.await .await
.ok_or("Connection closed without a ControlMessge".to_string())? .ok_or(error!("Connection closed without a ControlMessge"))??;
.map_err(|e| e.to_string())?;
// deserialize prologue // deserialize prologue
let prologue = match prologue.into_message_type() { let prologue = match prologue.into_message_type() {
TwoPartMessageType::HeaderOnly(header) => { TwoPartMessageType::HeaderOnly(header) => {
let prologue: ResponseStreamPrologue = serde_json::from_slice(&header) let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
.map_err(|e| format!("Failed to deserialize ControlMessage: {}", e))?; .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
prologue prologue
} }
_ => { _ => {
...@@ -415,7 +457,7 @@ async fn tcp_listener( ...@@ -415,7 +457,7 @@ async fn tcp_listener(
// us to trace the initial setup time vs the time to prologue // us to trace the initial setup time vs the time to prologue
if let Some(error) = &prologue.error { if let Some(error) = &prologue.error {
let _ = connection.send(Err(error.clone())); let _ = connection.send(Err(error.clone()));
return Err(format!("Received error prologue: {}", error)); return Err(error!("Received error prologue: {}", error));
} }
// we need to know the buffer size from the registration options; add this to the RequestRecvConnection object // we need to know the buffer size from the registration options; add this to the RequestRecvConnection object
...@@ -425,7 +467,7 @@ async fn tcp_listener( ...@@ -425,7 +467,7 @@ async fn tcp_listener(
.send(Ok(crate::pipeline::network::StreamReceiver { rx })) .send(Ok(crate::pipeline::network::StreamReceiver { rx }))
.is_err() .is_err()
{ {
return Err("The requester of the stream has been dropped before the connection was established".to_string()); return Err(error!("The requester of the stream has been dropped before the connection was established"));
} }
let (alive_tx, alive_rx) = mpsc::channel::<()>(1); let (alive_tx, alive_rx) = mpsc::channel::<()>(1);
...@@ -449,13 +491,8 @@ async fn tcp_listener( ...@@ -449,13 +491,8 @@ async fn tcp_listener(
// check the results of each of the tasks // check the results of each of the tasks
let (monitor_result, forward_result) = tokio::join!(monitor_task, forward_task); let (monitor_result, forward_result) = tokio::join!(monitor_task, forward_task);
// if either of the tasks failed, we need to return an error monitor_result??;
if let Err(e) = monitor_result { forward_result??;
return Err(format!("Monitor task failed: {}", e));
}
if let Err(e) = forward_result {
return Err(format!("Forward task failed: {}", e));
}
Ok(()) Ok(())
} }
...@@ -466,7 +503,7 @@ async fn tcp_listener( ...@@ -466,7 +503,7 @@ async fn tcp_listener(
control_tx: mpsc::Sender<Bytes>, control_tx: mpsc::Sender<Bytes>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
alive_rx: mpsc::Receiver<()>, alive_rx: mpsc::Receiver<()>,
) -> Result<(), String> { ) -> Result<()> {
// loop over reading the tcp stream and checking if the writer is closed // loop over reading the tcp stream and checking if the writer is closed
loop { loop {
tokio::select! { tokio::select! {
...@@ -481,12 +518,12 @@ async fn tcp_listener( ...@@ -481,12 +518,12 @@ async fn tcp_listener(
if !data.is_empty() { if !data.is_empty() {
if let Err(err) = response_tx.send(data).await { if let Err(err) = response_tx.send(data).await {
return Err(format!("handle_response_stream: Failed sending to response_tx: {err}")); return Err(error!("handle_response_stream: Failed sending to response_tx: {err}"));
}; };
} }
} }
Some(Err(e)) => { Some(Err(e)) => {
return Err(format!("Failed to read TwoPartCodec message from TcpStream: {}", e)); return Err(error!("Failed to read TwoPartCodec message from TcpStream: {}", e));
} }
None => { None => {
tracing::trace!("TcpStream closed naturally"); tracing::trace!("TcpStream closed naturally");
...@@ -536,7 +573,7 @@ async fn tcp_listener( ...@@ -536,7 +573,7 @@ async fn tcp_listener(
_socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>, _socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
ctx: Arc<dyn AsyncEngineContext>, ctx: Arc<dyn AsyncEngineContext>,
alive_tx: mpsc::Sender<()>, alive_tx: mpsc::Sender<()>,
) { ) -> Result<()> {
let alive_tx = alive_tx; let alive_tx = alive_tx;
tokio::select! { tokio::select! {
_ = ctx.stopped() => { _ = ctx.stopped() => {
...@@ -547,10 +584,10 @@ async fn tcp_listener( ...@@ -547,10 +584,10 @@ async fn tcp_listener(
tracing::trace!("response stream closed naturally") tracing::trace!("response stream closed naturally")
} }
} }
let mut framed_writer = _socket_tx; let framed_writer = _socket_tx;
if let Err(err) = framed_writer.get_mut().shutdown().await { let mut inner = framed_writer.into_inner();
// TODO: This might be fine to ignore inner.flush().await?;
tracing::error!("monitor shutdown error: {err}"); inner.shutdown().await?;
} Ok(())
} }
} }
...@@ -17,6 +17,8 @@ use serde::{Deserialize, Serialize}; ...@@ -17,6 +17,8 @@ use serde::{Deserialize, Serialize};
pub mod annotated; pub mod annotated;
pub type LeaseId = i64;
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Component { pub struct Component {
pub name: String, pub name: String,
...@@ -25,9 +27,18 @@ pub struct Component { ...@@ -25,9 +27,18 @@ pub struct Component {
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint { pub struct Endpoint {
/// Name of the endpoint.
pub name: String, pub name: String,
/// Component of the endpoint.
pub component: String, pub component: String,
/// Namespace of the component.
pub namespace: String, pub namespace: String,
/// Optional lease id for the endpoint.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lease: Option<LeaseId>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
...@@ -71,6 +82,7 @@ mod tests { ...@@ -71,6 +82,7 @@ mod tests {
name: "test_endpoint".to_string(), name: "test_endpoint".to_string(),
component: "test_component".to_string(), component: "test_component".to_string(),
namespace: "test_namespace".to_string(), namespace: "test_namespace".to_string(),
lease: None,
}; };
assert_eq!(endpoint.name, "test_endpoint"); assert_eq!(endpoint.name, "test_endpoint");
......
...@@ -99,15 +99,22 @@ impl Client { ...@@ -99,15 +99,22 @@ impl Client {
let token = runtime.primary_token(); let token = runtime.primary_token();
let client = let client =
etcd_client::Client::connect(config.etcd_url, config.etcd_connect_options).await?; etcd_client::Client::connect(config.etcd_url, config.etcd_connect_options).await?;
let lease_id = if config.attach_lease {
let lease_client = client.lease_client(); let lease_client = client.lease_client();
let lease = create_lease(lease_client, 10, token) let lease = create_lease(lease_client, 10, token)
.await .await
.context("creating primary lease")?; .context("creating primary lease")?;
lease.id
} else {
0
};
Ok(Client { Ok(Client {
client, client,
primary_lease: lease.id, primary_lease: lease_id,
runtime, runtime,
}) })
} }
...@@ -260,6 +267,10 @@ pub struct ClientOptions { ...@@ -260,6 +267,10 @@ pub struct ClientOptions {
#[builder(default)] #[builder(default)]
etcd_connect_options: Option<ConnectOptions>, etcd_connect_options: Option<ConnectOptions>,
/// If true, the client will attach a lease to the primary [`CancellationToken`].
#[builder(default = "true")]
attach_lease: bool,
} }
impl Default for ClientOptions { impl Default for ClientOptions {
...@@ -267,6 +278,7 @@ impl Default for ClientOptions { ...@@ -267,6 +278,7 @@ impl Default for ClientOptions {
ClientOptions { ClientOptions {
etcd_url: default_servers(), etcd_url: default_servers(),
etcd_connect_options: None, etcd_connect_options: None,
attach_lease: true,
} }
} }
} }
......
...@@ -19,14 +19,15 @@ mod integration { ...@@ -19,14 +19,15 @@ mod integration {
pub const DEFAULT_NAMESPACE: &str = "triton-init"; pub const DEFAULT_NAMESPACE: &str = "triton-init";
use futures::StreamExt; use futures::StreamExt;
use std::sync::Arc; use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use triton_distributed::{ use triton_distributed::{
pipeline::{ pipeline::{
async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
ResponseStream, SingleIn, ResponseStream, SingleIn,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
DistributedRuntime, Result, Runtime, Worker, DistributedRuntime, ErrorContext, Result, Runtime, Worker,
}; };
#[test] #[test]
...@@ -97,6 +98,14 @@ mod integration { ...@@ -97,6 +98,14 @@ mod integration {
} }
async fn client(runtime: DistributedRuntime) -> Result<()> { async fn client(runtime: DistributedRuntime) -> Result<()> {
// get the run duration from env
let run_duration = std::env::var("TRD_SOAK_RUN_DURATION").unwrap_or("1m".to_string());
let run_duration =
humantime::parse_duration(&run_duration).unwrap_or(Duration::from_secs(60));
let batch_load = std::env::var("TRD_SOAK_BATCH_LOAD").unwrap_or("1000".to_string());
let batch_load: usize = batch_load.parse().unwrap_or(1000);
let client = runtime let client = runtime
.namespace(DEFAULT_NAMESPACE)? .namespace(DEFAULT_NAMESPACE)?
.component("backend")? .component("backend")?
...@@ -107,13 +116,26 @@ mod integration { ...@@ -107,13 +116,26 @@ mod integration {
client.wait_for_endpoints().await?; client.wait_for_endpoints().await?;
let client = Arc::new(client); let client = Arc::new(client);
// spawn 20000 tasks to put load on the server let start = Instant::now();
let mut count = 0;
loop {
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for _ in 0..20000 { for _ in 0..batch_load {
let client = client.clone(); let client = client.clone();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let mut stream = client.random("hello world".to_string().into()).await?; let mut stream = tokio::time::timeout(
while let Some(_resp) = stream.next().await {} Duration::from_secs(30),
client.random("hello world".to_string().into()),
)
.await
.context("request timed out")??;
while let Some(_resp) =
tokio::time::timeout(Duration::from_secs(30), stream.next())
.await
.context("stream timed out")?
{}
Ok::<(), Error>(()) Ok::<(), Error>(())
})); }));
} }
...@@ -122,6 +144,16 @@ mod integration { ...@@ -122,6 +144,16 @@ mod integration {
task.await??; task.await??;
} }
let elapsed = start.elapsed();
count += batch_load;
println!("elapsed: {:?}; count: {}", elapsed, count);
if elapsed > run_duration {
println!("done");
break;
}
}
Ok(()) Ok(())
} }
} }
{
"folders": [
{
"path": "."
}
],
"settings": {
"rust-analyzer.linkedProjects": [
"llm/rust/Cargo.toml",
"runtime/rust/Cargo.toml",
"runtime/rust/python-wheel/Cargo.toml",
"examples/rust/Cargo.toml"
],
"rust-analyzer.procMacro.enable": true
},
"extensions": {
"recommendations": [
"ms-python.python",
"rust-lang.rust-analyzer"
]
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment