Commit b705549c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

refactor: Handle JSON serialize / de-serialize errors. (#156)

parent 139a9a83
...@@ -338,7 +338,16 @@ async fn process_stream( ...@@ -338,7 +338,16 @@ async fn process_stream(
let mut stream = stream; let mut stream = stream;
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
// Convert the response to a PyObject using Python's GIL // Convert the response to a PyObject using Python's GIL
let annotated: RsAnnotated<serde_json::Value> = serde_json::from_value(response).unwrap(); // TODO: Remove the clone, but still log the full JSON string on error. But how?
let annotated: RsAnnotated<serde_json::Value> = match serde_json::from_value(
response.clone(),
) {
Ok(a) => a,
Err(err) => {
tracing::error!(%err, %response, "process_stream: Failed de-serializing JSON into RsAnnotated");
break;
}
};
let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| { let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) { let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) {
......
...@@ -29,14 +29,15 @@ pub struct WorkerConfig { ...@@ -29,14 +29,15 @@ pub struct WorkerConfig {
} }
impl WorkerConfig { impl WorkerConfig {
/// Instantiates and reads server configurations from appropriate sources.
/// Panics on invalid configuration.
pub fn from_settings() -> Self { pub fn from_settings() -> Self {
// Instantiates and reads server configurations from appropriate sources.
// All calls should be global and thread safe. // All calls should be global and thread safe.
Figment::new() Figment::new()
.merge(Serialized::defaults(Self::default())) .merge(Serialized::defaults(Self::default()))
.merge(Env::prefixed("TRITON_WORKER_")) .merge(Env::prefixed("TRITON_WORKER_"))
.extract() .extract()
.unwrap() .unwrap() // safety: Called on startup, so panic is reasonable
} }
} }
......
...@@ -152,10 +152,15 @@ impl StreamSender { ...@@ -152,10 +152,15 @@ impl StreamSender {
pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> { pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
if let Some(prologue) = self.prologue.take() { if let Some(prologue) = self.prologue.take() {
let prologue = ResponseStreamPrologue { error, ..prologue }; let prologue = ResponseStreamPrologue { error, ..prologue };
let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
Ok(b) => b.into(),
Err(err) => {
tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
return Err("Invalid prologue".to_string());
}
};
self.tx self.tx
.send(TwoPartMessage::from_header( .send(TwoPartMessage::from_header(header_bytes))
serde_json::to_vec(&prologue).unwrap().into(),
))
.await .await
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
} else { } else {
......
...@@ -127,8 +127,18 @@ where ...@@ -127,8 +127,18 @@ where
// next build the two part message where we package the connection info and the request into // next build the two part message where we package the connection info and the request into
// a single Vec<u8> that can be sent over the wire. // a single Vec<u8> that can be sent over the wire.
// --- package this up in the WorkQueuePublisher --- // --- package this up in the WorkQueuePublisher ---
let ctrl = serde_json::to_vec(&control_message).unwrap(); let ctrl = match serde_json::to_vec(&control_message) {
let data = serde_json::to_vec(&request).unwrap(); Ok(ctrl) => ctrl,
Err(err) => {
anyhow::bail!("Failed serializing RequestControlMessage to JSON array: {err}");
}
};
let data = match serde_json::to_vec(&request) {
Ok(data) => data,
Err(err) => {
anyhow::bail!("Failed serializing request to JSON array: {err}");
}
};
tracing::trace!( tracing::trace!(
"[req: {}] packaging two-part message; ctrl: {} bytes, data: {} bytes", "[req: {}] packaging two-part message; ctrl: {} bytes, data: {} bytes",
...@@ -143,7 +153,7 @@ where ...@@ -143,7 +153,7 @@ where
// or it should take a two part message directly // or it should take a two part message directly
// todo - update this // todo - update this
let codec = TwoPartCodec::default(); let codec = TwoPartCodec::default();
let buffer = codec.encode_message(msg).unwrap(); let buffer = codec.encode_message(msg)?;
// TRANSPORT ABSTRACT REQUIRED - END HERE // TRANSPORT ABSTRACT REQUIRED - END HERE
...@@ -164,9 +174,15 @@ where ...@@ -164,9 +174,15 @@ where
let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx); let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
let stream = stream.map(|msg| { let stream = stream.filter_map(|msg| async move {
let resp: U = serde_json::from_slice(&msg).unwrap(); match serde_json::from_slice::<U>(&msg) {
resp Ok(r) => Some(r),
Err(err) => {
let json_str = String::from_utf8_lossy(&msg);
tracing::error!(%err, %json_str, "Failed deserializing JSON to response");
None
}
}
}); });
Ok(ResponseStream::new(Box::pin(stream), engine_ctx)) Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
......
...@@ -38,7 +38,15 @@ where ...@@ -38,7 +38,15 @@ where
header.len(), header.len(),
data.len() data.len()
); );
let control_msg: RequestControlMessage = serde_json::from_slice(&header).unwrap(); let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
Ok(cm) => cm,
Err(err) => {
let json_str = String::from_utf8_lossy(&header);
return Err(PipelineError::DeserializationError(
format!("Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"),
));
}
};
let request: T = serde_json::from_slice(&data)?; let request: T = serde_json::from_slice(&data)?;
(control_msg, request) (control_msg, request)
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
use std::sync::Arc; use std::sync::Arc;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use tokio::io::{ReadHalf, WriteHalf};
use tokio::{io::AsyncWriteExt, net::TcpStream}; use tokio::{io::AsyncWriteExt, net::TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
...@@ -82,7 +83,7 @@ impl TcpClient { ...@@ -82,7 +83,7 @@ impl TcpClient {
let stream = TcpClient::connect(&info.address).await?; let stream = TcpClient::connect(&info.address).await?;
let (read_half, write_half) = tokio::io::split(stream); let (read_half, write_half) = tokio::io::split(stream);
let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default()); let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default()); let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
// this is a oneshot channel that will be used to signal when the stream is closed // this is a oneshot channel that will be used to signal when the stream is closed
...@@ -90,11 +91,59 @@ impl TcpClient { ...@@ -90,11 +91,59 @@ impl TcpClient {
// the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
// so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
// captured by the monitor task // captured by the monitor task
let (mut alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>(); let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
// monitors the channel for a cancellation signal tokio::spawn(control_handler(framed_reader, context, alive_tx));
// this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
tokio::spawn(async move { // transport specific handshake message
let handshake = CallHomeHandshake {
subject: info.subject,
stream_type: StreamType::Response,
};
let handshake_bytes = match serde_json::to_vec(&handshake) {
Ok(hb) => hb,
Err(err) => {
return Err(format!(
"create_response_steam: Error converting CallHomeHandshake to JSON array: {err:#}"
));
}
};
let msg = TwoPartMessage::from_header(handshake_bytes.into());
// issue the the first tcp handshake message
framed_writer
.send(msg)
.await
.map_err(|e| format!("failed to send handshake: {:?}", e))?;
// set up the channel to send bytes to the transport layer
let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(16);
// forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
tokio::spawn(forward_handler(bytes_rx, framed_writer, alive_rx));
// set up the prologue for the stream
// this might have transport specific metadata in the future
let prologue = Some(ResponseStreamPrologue { error: None });
// create the stream sender
let stream_sender = StreamSender {
tx: bytes_tx,
prologue,
};
Ok(stream_sender)
}
}
/// monitors the channel for a cancellation signal
/// this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
async fn control_handler(
mut framed_reader: FramedRead<ReadHalf<TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>,
mut alive_tx: tokio::sync::oneshot::Sender<()>,
) {
loop { loop {
tokio::select! { tokio::select! {
msg = framed_reader.next() => { msg = framed_reader.next() => {
...@@ -102,7 +151,15 @@ impl TcpClient { ...@@ -102,7 +151,15 @@ impl TcpClient {
Some(Ok(two_part_msg)) => { Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() { match two_part_msg.optional_parts() {
(Some(bytes), None) => { (Some(bytes), None) => {
let msg: ControlMessage = serde_json::from_slice(bytes).unwrap(); let msg: ControlMessage = match serde_json::from_slice(bytes) {
Ok(msg) => msg,
Err(err) => {
let json_str = String::from_utf8_lossy(bytes);
tracing::error!(%err, %json_str, "control_handler fatal error deserializing JSON to ControlMessage");
context.kill();
break;
}
};
match msg { match msg {
ControlMessage::Stop => { ControlMessage::Stop => {
context.stop(); context.stop();
...@@ -136,34 +193,16 @@ impl TcpClient { ...@@ -136,34 +193,16 @@ impl TcpClient {
} }
} }
// framed_writer.get_mut().shutdown().await.unwrap(); // framed_writer.get_mut().shutdown().await.unwrap();
}); }
// transport specific handshake message
let handshake = CallHomeHandshake {
subject: info.subject,
stream_type: StreamType::Response,
};
let handshake_bytes = serde_json::to_vec(&handshake).unwrap();
let msg = TwoPartMessage::from_header(handshake_bytes.into());
// issue the the first tcp handshake message
framed_writer
.send(msg)
.await
.map_err(|e| format!("failed to send handshake: {:?}", e))?;
// set up the channel to send bytes to the transport layer
let (bytes_tx, mut bytes_rx) = tokio::sync::mpsc::channel(16);
// forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel async fn forward_handler(
tokio::spawn(async move { mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
mut framed_writer: FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>,
alive_rx: tokio::sync::oneshot::Receiver<()>,
) {
while let Some(msg) = bytes_rx.recv().await { while let Some(msg) = bytes_rx.recv().await {
if let Err(e) = framed_writer.send(msg).await { if let Err(e) = framed_writer.send(msg).await {
tracing::trace!( tracing::trace!(%e, "failed to send message to stream; possible disconnect");
"failed to send message to stream; possible disconnect: {:?}",
e
);
// TODO - possibly propagate the error upstream // TODO - possibly propagate the error upstream
break; break;
...@@ -173,18 +212,4 @@ impl TcpClient { ...@@ -173,18 +212,4 @@ impl TcpClient {
if let Err(e) = framed_writer.get_mut().shutdown().await { if let Err(e) = framed_writer.get_mut().shutdown().await {
tracing::trace!("failed to shutdown writer: {:?}", e); tracing::trace!("failed to shutdown writer: {:?}", e);
} }
});
// set up the prologue for the stream
// this might have transport specific metadata in the future
let prologue = Some(ResponseStreamPrologue { error: None });
// create the stream sender
let stream_sender = StreamSender {
tx: bytes_tx,
prologue,
};
Ok(stream_sender)
}
} }
...@@ -298,7 +298,14 @@ async fn tcp_listener( ...@@ -298,7 +298,14 @@ async fn tcp_listener(
}; };
loop { loop {
let (stream, _addr) = listener.accept().await.unwrap(); let (stream, _addr) = match listener.accept().await {
Ok(x) => x,
Err(err) => {
// TODO: Probably this is normal, user Ctrl-C something like that, find out
tracing::info!(%err, addr, "TCP listener closed");
break;
}
};
stream.set_nodelay(true).unwrap(); stream.set_nodelay(true).unwrap();
tokio::spawn(handle_connection(stream, state.clone())); tokio::spawn(handle_connection(stream, state.clone()));
} }
...@@ -336,18 +343,16 @@ async fn tcp_listener( ...@@ -336,18 +343,16 @@ async fn tcp_listener(
// we await on the raw bytes which should come in as a header only message // we await on the raw bytes which should come in as a header only message
// todo - improve error handling - check for no data // todo - improve error handling - check for no data
if first_message.header().is_none() { let handshake: CallHomeHandshake = match first_message.header() {
return Err("Expected ControlMessage, got DataMessage".to_string()); Some(header) => serde_json::from_slice(header).map_err(|e| {
}
// deserialize the [`CallHomeHandshake`] message
let handshake: CallHomeHandshake = serde_json::from_slice(first_message.header().unwrap())
.map_err(|e| {
format!( format!(
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {}", "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
e
) )
})?; })?,
None => {
return Err("Expected ControlMessage, got DataMessage".to_string());
}
};
// branch here to handle sender stream or receiver stream // branch here to handle sender stream or receiver stream
match handshake.stream_type { match handshake.stream_type {
...@@ -357,7 +362,7 @@ async fn tcp_listener( ...@@ -357,7 +362,7 @@ async fn tcp_listener(
.await .await
} }
} }
.map_err(|e| format!("Failed to process stream: {}", e)) .map_err(|e| format!("Failed to process stream: {e}"))
} }
async fn process_request_stream() -> Result<(), String> { async fn process_request_stream() -> Result<(), String> {
...@@ -374,7 +379,7 @@ async fn tcp_listener( ...@@ -374,7 +379,7 @@ async fn tcp_listener(
.lock().await .lock().await
.rx_subjects .rx_subjects
.remove(&subject) .remove(&subject)
.ok_or(format!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?; .ok_or(format!("Subject not found: {subject}; upstream publisher specified a subject unknown to the downsteam subscriber"))?;
// unwrap response_stream // unwrap response_stream
let RequestedRecvConnection { let RequestedRecvConnection {
...@@ -475,7 +480,9 @@ async fn tcp_listener( ...@@ -475,7 +480,9 @@ async fn tcp_listener(
} }
if !data.is_empty() { if !data.is_empty() {
response_tx.send(data).await.unwrap(); if let Err(err) = response_tx.send(data).await {
return Err(format!("handle_response_stream: Failed sending to response_tx: {err}"));
};
} }
} }
Some(Err(e)) => { Some(Err(e)) => {
...@@ -541,6 +548,9 @@ async fn tcp_listener( ...@@ -541,6 +548,9 @@ async fn tcp_listener(
} }
} }
let mut framed_writer = _socket_tx; let mut framed_writer = _socket_tx;
framed_writer.get_mut().shutdown().await.unwrap(); if let Err(err) = framed_writer.get_mut().shutdown().await {
// TODO: This might be fine to ignore
tracing::error!("monitor shutdown error: {err}");
}
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment