client.rs 10.9 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
Ryan Olson's avatar
Ryan Olson committed
3
4
5
6

use std::sync::Arc;

use futures::{SinkExt, StreamExt};
7
8
9
10
11
12
use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
use tokio::{
    io::AsyncWriteExt,
    net::TcpStream,
    time::{self, Duration, Instant},
};
Ryan Olson's avatar
Ryan Olson committed
13
14
15
16
17
use tokio_util::codec::{FramedRead, FramedWrite};

use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{
18
    ConnectionInfo, ResponseStreamPrologue, StreamSender,
Ryan Olson's avatar
Ryan Olson committed
19
20
    codec::{TwoPartCodec, TwoPartMessage},
    tcp::StreamType,
21
};
22
use crate::{ErrorContext, Result, error}; // Import SinkExt to use the `send` method
Ryan Olson's avatar
Ryan Olson committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

/// Client side of the TCP response-stream transport.
///
/// The response stream itself is created through the associated function
/// `TcpClient::create_response_stream`, which does not take `self`.
// `worker_id` is stored but never read anywhere in this file, hence the lint allowance.
#[allow(dead_code)]
#[derive(Debug)]
pub struct TcpClient {
    /// Identifier for this worker; a random UUID string when built via `Default`.
    worker_id: String,
}

impl Default for TcpClient {
    /// Builds a client whose `worker_id` is a freshly generated random (v4) UUID string.
    fn default() -> Self {
        let worker_id = uuid::Uuid::new_v4().to_string();
        Self { worker_id }
    }
}

impl TcpClient {
    pub fn new(worker_id: String) -> Self {
        TcpClient { worker_id }
    }

42
    async fn connect(address: &str) -> std::io::Result<TcpStream> {
43
        // try to connect to the address; retry with linear backoff if AddrNotAvailable
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
        let backoff = std::time::Duration::from_millis(200);
        loop {
            match TcpStream::connect(address).await {
                Ok(socket) => {
                    socket.set_nodelay(true)?;
                    return Ok(socket);
                }
                Err(e) => {
                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
                        tracing::warn!("retry warning: failed to connect: {:?}", e);
                        tokio::time::sleep(backoff).await;
                    } else {
                        return Err(e);
                    }
                }
            }
        }
Ryan Olson's avatar
Ryan Olson committed
61
62
    }

63
    pub async fn create_response_stream(
Ryan Olson's avatar
Ryan Olson committed
64
65
        context: Arc<dyn AsyncEngineContext>,
        info: ConnectionInfo,
66
67
68
    ) -> Result<StreamSender> {
        let info =
            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
Ryan Olson's avatar
Ryan Olson committed
69
70
71
        tracing::trace!("Creating response stream for {:?}", info);

        if info.stream_type != StreamType::Response {
72
            return Err(error!(
Ryan Olson's avatar
Ryan Olson committed
73
74
75
76
77
78
                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
                info.stream_type
            ));
        }

        if info.context != context.id() {
79
            return Err(error!(
Ryan Olson's avatar
Ryan Olson committed
80
81
82
83
84
85
86
87
88
                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
                context.id(),
                info.context
            ));
        }

        let stream = TcpClient::connect(&info.address).await?;
        let (read_half, write_half) = tokio::io::split(stream);

89
        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
Ryan Olson's avatar
Ryan Olson committed
90
91
92
93
94
95
96
        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());

        // this is a oneshot channel that will be used to signal when the stream is closed
        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
        // captured by the monitor task
97
98
        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();

99
        let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
Ryan Olson's avatar
Ryan Olson committed
100
101
102
103
104
105
106

        // transport specific handshake message
        let handshake = CallHomeHandshake {
            subject: info.subject,
            stream_type: StreamType::Response,
        };

107
108
109
        let handshake_bytes = match serde_json::to_vec(&handshake) {
            Ok(hb) => hb,
            Err(err) => {
110
                return Err(error!(
111
                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
112
113
114
                ));
            }
        };
Ryan Olson's avatar
Ryan Olson committed
115
116
117
118
119
120
        let msg = TwoPartMessage::from_header(handshake_bytes.into());

        // issue the the first tcp handshake message
        framed_writer
            .send(msg)
            .await
121
            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
Ryan Olson's avatar
Ryan Olson committed
122
123

        // set up the channel to send bytes to the transport layer
124
        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
Ryan Olson's avatar
Ryan Olson committed
125
126

        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
127

128
        let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
129
130
131
132
133
134
135
136
137

        tokio::spawn(async move {
            // await both tasks
            let (reader, writer) = tokio::join!(reader_task, writer_task);

            match (reader, writer) {
                (Ok(reader), Ok(writer)) => {
                    let reader = reader.into_inner();

138
139
140
141
142
143
144
145
                    let writer = match writer {
                        Ok(writer) => writer.into_inner(),
                        Err(e) => {
                            tracing::error!("failed to join writer task: {:?}", e);
                            return Err(e);
                        }
                    };

146
147
                    let mut stream = reader.unsplit(writer);

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
                    // await the tcp server to shutdown the socket connection
                    // set a timeout for the server shutdown
                    let mut buf = vec![0u8; 1024];
                    let deadline = Instant::now() + Duration::from_secs(10);
                    loop {
                        let n = time::timeout_at(deadline, stream.read(&mut buf))
                            .await
                            .inspect_err(|_| {
                                tracing::debug!("server did not close socket within the deadline");
                            })?
                            .inspect_err(|e| {
                                tracing::debug!("failed to read from stream: {:?}", e);
                            })?;
                        if n == 0 {
                            // Server has closed (FIN)
                            break;
                        }
                    }

                    Ok(())
168
169
170
171
172
173
174
                }
                _ => {
                    tracing::error!("failed to join reader and writer tasks");
                    anyhow::bail!("failed to join reader and writer tasks");
                }
            }
        });
Ryan Olson's avatar
Ryan Olson committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188

        // set up the prologue for the stream
        // this might have transport specific metadata in the future
        let prologue = Some(ResponseStreamPrologue { error: None });

        // create the stream sender
        let stream_sender = StreamSender {
            tx: bytes_tx,
            prologue,
        };

        Ok(stream_sender)
    }
}
189

190
191
async fn handle_reader(
    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
192
    context: Arc<dyn AsyncEngineContext>,
193
194
195
196
    alive_tx: tokio::sync::oneshot::Sender<()>,
) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
    let mut framed_reader = framed_reader;
    let mut alive_tx = alive_tx;
197
198
199
200
201
202
203
    loop {
        tokio::select! {
            msg = framed_reader.next() => {
                match msg {
                    Some(Ok(two_part_msg)) => {
                        match two_part_msg.optional_parts() {
                           (Some(bytes), None) => {
204
                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
205
                                    Ok(msg) => msg,
206
207
                                    Err(_) => {
                                        // TODO(#171) - address fatal errors
208
                                        panic!("fatal error - invalid control message detected");
209
210
                                    }
                                };
211

212
213
214
215
216
217
                                match msg {
                                    ControlMessage::Stop => {
                                        context.stop();
                                    }
                                    ControlMessage::Kill => {
                                        context.kill();
218
219
220
221
                                    }
                                    ControlMessage::Sentinel => {
                                        // TODO(#171) - address fatal errors
                                        panic!("received a sentinel message; this should never happen");
222
223
224
225
                                    }
                                }
                           }
                           _ => {
226
                                panic!("received a non-control message; this should never happen");
227
228
229
                           }
                        }
                    }
230
231
232
                    Some(Err(_)) => {
                        // TODO(#171) - address fatal errors
                        // in this case the binary representation of the message is invalid
233
                        panic!("fatal error - failed to decode message from stream; invalid line protocol");
234
235
                    }
                    None => {
236
                        tracing::debug!("tcp stream closed by server");
237
                        break;
238
239
240
241
242
243
244
245
                    }
                }
            }
            _ = alive_tx.closed() => {
                break;
            }
        }
    }
246
    framed_reader
247
248
}

249
250
async fn handle_writer(
    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
251
252
    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
    alive_rx: tokio::sync::oneshot::Receiver<()>,
253
254
255
256
257
258
259
260
261
262
263
    context: Arc<dyn AsyncEngineContext>,
) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
    loop {
        let msg = tokio::select! {
            biased;

            _ = context.killed() => {
                tracing::trace!("context kill signal received; shutting down");
                break;
            }

264
265
266
267
268
            _ = context.stopped() => {
                tracing::trace!("context stop signal received; shutting down");
                break;
            }

269
270
271
272
273
274
275
276
277
278
279
            msg = bytes_rx.recv() => {
                match msg {
                    Some(msg) => msg,
                    None => {
                        tracing::trace!("response channel closed; shutting down");
                        break;
                    }
                }
            }
        };

280
        if let Err(e) = framed_writer.send(msg).await {
281
            tracing::trace!(
282
                "failed to send message to network; possible disconnect: {:?}",
283
284
                e
            );
285
286
287
            break;
        }
    }
288

289
290
291
292
293
294
295
    // send sentinel message
    let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
    let msg = TwoPartMessage::from_header(message.into());
    framed_writer.send(msg).await?;

    drop(alive_rx);
    Ok(framed_writer)
296
}