Commit b705549c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

refactor: Handle JSON serialize / de-serialize errors. (#156)

parent 139a9a83
......@@ -338,7 +338,16 @@ async fn process_stream(
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert the response to a PyObject using Python's GIL
let annotated: RsAnnotated<serde_json::Value> = serde_json::from_value(response).unwrap();
// TODO: Remove the clone, but still log the full JSON string on error. But how?
let annotated: RsAnnotated<serde_json::Value> = match serde_json::from_value(
response.clone(),
) {
Ok(a) => a,
Err(err) => {
tracing::error!(%err, %response, "process_stream: Failed de-serializing JSON into RsAnnotated");
break;
}
};
let annotated: RsAnnotated<PyObject> = annotated.map_data(|data| {
let result = Python::with_gil(|py| match pythonize::pythonize(py, &data) {
......
......@@ -29,14 +29,15 @@ pub struct WorkerConfig {
}
impl WorkerConfig {
/// Instantiates and reads server configurations from appropriate sources.
/// Panics on invalid configuration.
pub fn from_settings() -> Self {
// Instantiates and reads server configurations from appropriate sources.
// All calls should be global and thread safe.
Figment::new()
.merge(Serialized::defaults(Self::default()))
.merge(Env::prefixed("TRITON_WORKER_"))
.extract()
.unwrap()
.unwrap() // safety: Called on startup, so panic is reasonable
}
}
......
......@@ -152,10 +152,15 @@ impl StreamSender {
pub async fn send_prologue(&mut self, error: Option<String>) -> Result<(), String> {
if let Some(prologue) = self.prologue.take() {
let prologue = ResponseStreamPrologue { error, ..prologue };
let header_bytes: Bytes = match serde_json::to_vec(&prologue) {
Ok(b) => b.into(),
Err(err) => {
tracing::error!(%err, "send_prologue: ResponseStreamPrologue did not serialize to a JSON array");
return Err("Invalid prologue".to_string());
}
};
self.tx
.send(TwoPartMessage::from_header(
serde_json::to_vec(&prologue).unwrap().into(),
))
.send(TwoPartMessage::from_header(header_bytes))
.await
.map_err(|e| e.to_string())?;
} else {
......
......@@ -127,8 +127,18 @@ where
// next build the two part message where we package the connection info and the request into
// a single Vec<u8> that can be sent over the wire.
// --- package this up in the WorkQueuePublisher ---
let ctrl = serde_json::to_vec(&control_message).unwrap();
let data = serde_json::to_vec(&request).unwrap();
let ctrl = match serde_json::to_vec(&control_message) {
Ok(ctrl) => ctrl,
Err(err) => {
anyhow::bail!("Failed serializing RequestControlMessage to JSON array: {err}");
}
};
let data = match serde_json::to_vec(&request) {
Ok(data) => data,
Err(err) => {
anyhow::bail!("Failed serializing request to JSON array: {err}");
}
};
tracing::trace!(
"[req: {}] packaging two-part message; ctrl: {} bytes, data: {} bytes",
......@@ -143,7 +153,7 @@ where
// or it should take a two part message directly
// todo - update this
let codec = TwoPartCodec::default();
let buffer = codec.encode_message(msg).unwrap();
let buffer = codec.encode_message(msg)?;
// TRANSPORT ABSTRACT REQUIRED - END HERE
......@@ -164,9 +174,15 @@ where
let stream = tokio_stream::wrappers::ReceiverStream::new(response_stream.rx);
let stream = stream.map(|msg| {
let resp: U = serde_json::from_slice(&msg).unwrap();
resp
let stream = stream.filter_map(|msg| async move {
match serde_json::from_slice::<U>(&msg) {
Ok(r) => Some(r),
Err(err) => {
let json_str = String::from_utf8_lossy(&msg);
tracing::error!(%err, %json_str, "Failed deserializing JSON to response");
None
}
}
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
......
......@@ -38,7 +38,15 @@ where
header.len(),
data.len()
);
let control_msg: RequestControlMessage = serde_json::from_slice(&header).unwrap();
let control_msg: RequestControlMessage = match serde_json::from_slice(&header) {
Ok(cm) => cm,
Err(err) => {
let json_str = String::from_utf8_lossy(&header);
return Err(PipelineError::DeserializationError(
format!("Failed deserializing to RequestControlMessage. err={err}, json_str={json_str}"),
));
}
};
let request: T = serde_json::from_slice(&data)?;
(control_msg, request)
}
......
......@@ -16,6 +16,7 @@
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::io::{ReadHalf, WriteHalf};
use tokio::{io::AsyncWriteExt, net::TcpStream};
use tokio_util::codec::{FramedRead, FramedWrite};
......@@ -82,7 +83,7 @@ impl TcpClient {
let stream = TcpClient::connect(&info.address).await?;
let (read_half, write_half) = tokio::io::split(stream);
let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
// this is a oneshot channel that will be used to signal when the stream is closed
......@@ -90,53 +91,9 @@ impl TcpClient {
// the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
// so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
// captured by the monitor task
let (mut alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
// monitors the channel for a cancellation signal
// this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
tokio::spawn(async move {
loop {
tokio::select! {
msg = framed_reader.next() => {
match msg {
Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() {
(Some(bytes), None) => {
let msg: ControlMessage = serde_json::from_slice(bytes).unwrap();
match msg {
ControlMessage::Stop => {
context.stop();
break;
}
ControlMessage::Kill => {
context.kill();
break;
}
}
}
_ => {
// we should not receive this
}
}
}
Some(Err(e)) => {
panic!("failed to decode message from stream: {:?}", e);
// break;
}
None => {
// the stream was closed, we should stop the stream
return;
}
}
}
_ = alive_tx.closed() => {
// the channel was closed, we should stop the stream
break;
}
}
}
// framed_writer.get_mut().shutdown().await.unwrap();
});
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
tokio::spawn(control_handler(framed_reader, context, alive_tx));
// transport specific handshake message
let handshake = CallHomeHandshake {
......@@ -144,7 +101,14 @@ impl TcpClient {
stream_type: StreamType::Response,
};
let handshake_bytes = serde_json::to_vec(&handshake).unwrap();
let handshake_bytes = match serde_json::to_vec(&handshake) {
Ok(hb) => hb,
Err(err) => {
return Err(format!(
"create_response_steam: Error converting CallHomeHandshake to JSON array: {err:#}"
));
}
};
let msg = TwoPartMessage::from_header(handshake_bytes.into());
// issue the the first tcp handshake message
......@@ -154,26 +118,10 @@ impl TcpClient {
.map_err(|e| format!("failed to send handshake: {:?}", e))?;
// set up the channel to send bytes to the transport layer
let (bytes_tx, mut bytes_rx) = tokio::sync::mpsc::channel(16);
let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(16);
// forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
tokio::spawn(async move {
while let Some(msg) = bytes_rx.recv().await {
if let Err(e) = framed_writer.send(msg).await {
tracing::trace!(
"failed to send message to stream; possible disconnect: {:?}",
e
);
// TODO - possibly propagate the error upstream
break;
}
}
drop(alive_rx);
if let Err(e) = framed_writer.get_mut().shutdown().await {
tracing::trace!("failed to shutdown writer: {:?}", e);
}
});
tokio::spawn(forward_handler(bytes_rx, framed_writer, alive_rx));
// set up the prologue for the stream
// this might have transport specific metadata in the future
......@@ -188,3 +136,80 @@ impl TcpClient {
Ok(stream_sender)
}
}
/// monitors the channel for a cancellation signal
/// this task exits when the alive_rx half of the oneshot channel is closed or a stop/kill signal is received
async fn control_handler(
mut framed_reader: FramedRead<ReadHalf<TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>,
mut alive_tx: tokio::sync::oneshot::Sender<()>,
) {
loop {
tokio::select! {
msg = framed_reader.next() => {
match msg {
Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() {
(Some(bytes), None) => {
let msg: ControlMessage = match serde_json::from_slice(bytes) {
Ok(msg) => msg,
Err(err) => {
let json_str = String::from_utf8_lossy(bytes);
tracing::error!(%err, %json_str, "control_handler fatal error deserializing JSON to ControlMessage");
context.kill();
break;
}
};
match msg {
ControlMessage::Stop => {
context.stop();
break;
}
ControlMessage::Kill => {
context.kill();
break;
}
}
}
_ => {
// we should not receive this
}
}
}
Some(Err(e)) => {
panic!("failed to decode message from stream: {:?}", e);
// break;
}
None => {
// the stream was closed, we should stop the stream
return;
}
}
}
_ = alive_tx.closed() => {
// the channel was closed, we should stop the stream
break;
}
}
}
// framed_writer.get_mut().shutdown().await.unwrap();
}
async fn forward_handler(
mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
mut framed_writer: FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>,
alive_rx: tokio::sync::oneshot::Receiver<()>,
) {
while let Some(msg) = bytes_rx.recv().await {
if let Err(e) = framed_writer.send(msg).await {
tracing::trace!(%e, "failed to send message to stream; possible disconnect");
// TODO - possibly propagate the error upstream
break;
}
}
drop(alive_rx);
if let Err(e) = framed_writer.get_mut().shutdown().await {
tracing::trace!("failed to shutdown writer: {:?}", e);
}
}
......@@ -298,7 +298,14 @@ async fn tcp_listener(
};
loop {
let (stream, _addr) = listener.accept().await.unwrap();
let (stream, _addr) = match listener.accept().await {
Ok(x) => x,
Err(err) => {
// TODO: Probably this is normal, user Ctrl-C something like that, find out
tracing::info!(%err, addr, "TCP listener closed");
break;
}
};
stream.set_nodelay(true).unwrap();
tokio::spawn(handle_connection(stream, state.clone()));
}
......@@ -336,18 +343,16 @@ async fn tcp_listener(
// we await on the raw bytes which should come in as a header only message
// todo - improve error handling - check for no data
if first_message.header().is_none() {
return Err("Expected ControlMessage, got DataMessage".to_string());
}
// deserialize the [`CallHomeHandshake`] message
let handshake: CallHomeHandshake = serde_json::from_slice(first_message.header().unwrap())
.map_err(|e| {
let handshake: CallHomeHandshake = match first_message.header() {
Some(header) => serde_json::from_slice(header).map_err(|e| {
format!(
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {}",
e
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
)
})?;
})?,
None => {
return Err("Expected ControlMessage, got DataMessage".to_string());
}
};
// branch here to handle sender stream or receiver stream
match handshake.stream_type {
......@@ -357,7 +362,7 @@ async fn tcp_listener(
.await
}
}
.map_err(|e| format!("Failed to process stream: {}", e))
.map_err(|e| format!("Failed to process stream: {e}"))
}
async fn process_request_stream() -> Result<(), String> {
......@@ -374,7 +379,7 @@ async fn tcp_listener(
.lock().await
.rx_subjects
.remove(&subject)
.ok_or(format!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
.ok_or(format!("Subject not found: {subject}; upstream publisher specified a subject unknown to the downsteam subscriber"))?;
// unwrap response_stream
let RequestedRecvConnection {
......@@ -475,7 +480,9 @@ async fn tcp_listener(
}
if !data.is_empty() {
response_tx.send(data).await.unwrap();
if let Err(err) = response_tx.send(data).await {
return Err(format!("handle_response_stream: Failed sending to response_tx: {err}"));
};
}
}
Some(Err(e)) => {
......@@ -541,6 +548,9 @@ async fn tcp_listener(
}
}
let mut framed_writer = _socket_tx;
framed_writer.get_mut().shutdown().await.unwrap();
if let Err(err) = framed_writer.get_mut().shutdown().await {
// TODO: This might be fine to ignore
tracing::error!("monitor shutdown error: {err}");
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment