Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
5ed8c1c0
Commit
5ed8c1c0
authored
Feb 03, 2025
by
Ryan Olson
Committed by
GitHub
Feb 03, 2025
Browse files
feat: rust - initial commit
the journey begins
parent
4017bd18
Changes
54
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
3012 additions
and
0 deletions
+3012
-0
runtime/rust/src/service.rs
runtime/rust/src/service.rs
+182
-0
runtime/rust/src/transports.rs
runtime/rust/src/transports.rs
+24
-0
runtime/rust/src/transports/etcd.rs
runtime/rust/src/transports/etcd.rs
+283
-0
runtime/rust/src/transports/etcd/kv.rs
runtime/rust/src/transports/etcd/kv.rs
+16
-0
runtime/rust/src/transports/etcd/lease.rs
runtime/rust/src/transports/etcd/lease.rs
+118
-0
runtime/rust/src/transports/nats.rs
runtime/rust/src/transports/nats.rs
+924
-0
runtime/rust/src/transports/nats/slug.rs
runtime/rust/src/transports/nats/slug.rs
+164
-0
runtime/rust/src/transports/tcp.rs
runtime/rust/src/transports/tcp.rs
+17
-0
runtime/rust/src/worker.rs
runtime/rust/src/worker.rs
+208
-0
runtime/rust/tests/common/engines.rs
runtime/rust/tests/common/engines.rs
+229
-0
runtime/rust/tests/common/mock.rs
runtime/rust/tests/common/mock.rs
+497
-0
runtime/rust/tests/common/mod.rs
runtime/rust/tests/common/mod.rs
+2
-0
runtime/rust/tests/lifecycle.rs
runtime/rust/tests/lifecycle.rs
+58
-0
runtime/rust/tests/pipeline.rs
runtime/rust/tests/pipeline.rs
+290
-0
No files found.
runtime/rust/src/service.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
// TODO - refactor this entire module
//
// we want to carry forward the concept of live vs ready for the components
// we will want to associate the components cancellation token with the
// component's "service state"
use
crate
::{
log
,
transports
::
nats
,
Result
};
use
async_nats
::
Message
;
use
async_stream
::
try_stream
;
use
bytes
::
Bytes
;
use
derive_getters
::
Dissolve
;
use
futures
::
stream
::
StreamExt
;
use
serde
::{
de
::
DeserializeOwned
,
Deserialize
,
Serialize
};
use
std
::
time
::
Duration
;
/// Client for querying NATS-based service instances (e.g. collecting
/// `$SRV.STATS` responses) over a shared NATS connection.
pub struct ServiceClient {
    // Underlying NATS connection used for request/reply and subscriptions.
    nats_client: nats::Client,
}
impl ServiceClient {
    /// Wrap an existing [`nats::Client`] in a [`ServiceClient`].
    #[allow(dead_code)]
    pub(crate) fn new(nats_client: nats::Client) -> Self {
        Self { nats_client }
    }
}
/// A collection of [`ServiceInfo`] records gathered during one discovery pass.
/// May contain several instances (different `id`s) of the same service name.
pub struct ServiceSet {
    // Discovered service instances, in arrival order.
    services: Vec<ServiceInfo>,
}
/// Metadata describing a single service instance, deserialized from a
/// NATS service stats response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceInfo {
    // Service name (shared across instances of the same service).
    pub name: String,
    // Unique identifier of this particular instance.
    pub id: String,
    // Version string reported by the instance.
    pub version: String,
    // Start timestamp as reported by the instance (string form; format is
    // whatever the service emitted — not parsed here).
    pub started: String,
    // Endpoints exposed by this instance.
    pub endpoints: Vec<EndpointInfo>,
}
/// A single endpoint of a service instance, plus any extra stats fields.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct EndpointInfo {
    // Endpoint name.
    pub name: String,
    // NATS subject the endpoint listens on.
    pub subject: String,
    // All remaining JSON fields are flattened into this opaque metrics blob.
    #[serde(flatten)]
    pub data: Metrics,
}
/// Opaque JSON payload of endpoint metrics; use [`Metrics::decode`] to
/// deserialize it into a concrete type.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct Metrics(pub serde_json::Value);
impl Metrics {
    /// Consume the wrapped JSON value and deserialize it into `T`.
    ///
    /// Returns an error if the JSON does not match `T`'s shape.
    pub fn decode<T>(self) -> Result<T>
    where
        T: DeserializeOwned,
    {
        serde_json::from_value(self.0).map_err(Into::into)
    }
}
impl ServiceClient {
    /// Issue a single request/reply round-trip on `subject` and return the
    /// raw reply [`Message`].
    pub async fn unary(
        &self,
        subject: impl Into<String>,
        payload: impl Into<Bytes>,
    ) -> Result<Message> {
        let response = self
            .nats_client
            .client()
            .request(subject.into(), payload.into())
            .await?;
        Ok(response)
    }

    /// Broadcast a stats request for `service_name` and gather all replies
    /// that arrive within a fixed 1-second window into a [`ServiceSet`].
    ///
    /// Empty payloads are skipped; malformed payloads are logged and dropped
    /// (they do not fail the whole collection).
    pub async fn collect_services(&self, service_name: &str) -> Result<ServiceSet> {
        let mut sub = self.nats_client.service_subscriber(service_name).await?;
        // Hard deadline for the whole collection pass.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(1);
        let services: Vec<Result<ServiceInfo>> = try_stream! {
            // Stop as soon as the deadline fires or the subscription ends.
            while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await {
                if message.payload.is_empty() {
                    continue;
                }
                let service = serde_json::from_slice::<ServiceInfo>(&message.payload)?;
                log::trace!("service: {:?}", service);
                yield service;
            }
        }
        .collect()
        .await;
        // split ok and error results
        let (ok, err): (Vec<_>, Vec<_>) = services.into_iter().partition(Result::is_ok);
        if !err.is_empty() {
            // Deserialization failures are reported but not propagated.
            log::error!("failed to collect services: {:?}", err);
        }
        Ok(ServiceSet {
            // Safe: `ok` only contains `Ok` values after the partition above.
            services: ok.into_iter().map(Result::unwrap).collect(),
        })
    }
}
impl ServiceSet {
    /// Consume the set, yielding every endpoint of every discovered service.
    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
        self.services.into_iter().flat_map(|svc| svc.endpoints)
    }
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_service_set
()
{
let
services
=
vec!
[
ServiceInfo
{
name
:
"service1"
.to_string
(),
id
:
"1"
.to_string
(),
version
:
"1.0"
.to_string
(),
started
:
"2021-01-01"
.to_string
(),
endpoints
:
vec!
[
EndpointInfo
{
name
:
"endpoint1"
.to_string
(),
subject
:
"subject1"
.to_string
(),
data
:
Metrics
(
serde_json
::
json!
({
"key"
:
"value1"
})),
},
EndpointInfo
{
name
:
"endpoint2-foo"
.to_string
(),
subject
:
"subject2"
.to_string
(),
data
:
Metrics
(
serde_json
::
json!
({
"key"
:
"value1"
})),
},
],
},
ServiceInfo
{
name
:
"service1"
.to_string
(),
id
:
"2"
.to_string
(),
version
:
"1.0"
.to_string
(),
started
:
"2021-01-01"
.to_string
(),
endpoints
:
vec!
[
EndpointInfo
{
name
:
"endpoint1"
.to_string
(),
subject
:
"subject1"
.to_string
(),
data
:
Metrics
(
serde_json
::
json!
({
"key"
:
"value1"
})),
},
EndpointInfo
{
name
:
"endpoint2-bar"
.to_string
(),
subject
:
"subject2"
.to_string
(),
data
:
Metrics
(
serde_json
::
json!
({
"key"
:
"value2"
})),
},
],
},
];
let
service_set
=
ServiceSet
{
services
};
let
endpoints
:
Vec
<
_
>
=
service_set
.into_endpoints
()
.filter
(|
e
|
e
.name
.starts_with
(
"endpoint2"
))
.collect
();
assert_eq!
(
endpoints
.len
(),
2
);
}
}
runtime/rust/src/transports.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The Transports module hosts all the network communication stacks used for talking
//! to services or moving data around the network.
//!
//! These are the low-level building blocks for the distributed system.
pub
mod
etcd
;
pub
mod
nats
;
pub
mod
tcp
;
runtime/rust/src/transports/etcd.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use
crate
::{
error
,
log
,
CancellationToken
,
ErrorContext
,
Result
,
Runtime
};
use
async_nats
::
jetstream
::
kv
;
use
derive_builder
::
Builder
;
use
derive_getters
::
Dissolve
;
use
futures
::
StreamExt
;
use
tokio
::
sync
::
mpsc
;
use
validator
::
Validate
;
use
etcd_client
::{
Compare
,
CompareOp
,
GetOptions
,
KeyValue
,
PutOptions
,
Txn
,
TxnOp
,
WatchOptions
,
Watcher
,
};
pub
use
etcd_client
::{
ConnectOptions
,
LeaseClient
};
mod
lease
;
use
lease
::
*
;
//pub use etcd::ConnectOptions as EtcdConnectOptions;
/// ETCD Client
///
/// Wraps an [`etcd_client::Client`] together with the process-wide primary
/// lease and the [`Runtime`] whose lifetime is tied to that lease.
#[derive(Clone)]
pub struct Client {
    // Underlying etcd connection.
    client: etcd_client::Client,
    // ID of the primary lease created at construction time.
    primary_lease: i64,
    // Runtime providing cancellation tokens and the secondary task executor.
    runtime: Runtime,
}
/// An etcd lease paired with the [`CancellationToken`] that controls it.
#[derive(Debug, Clone)]
pub struct Lease {
    /// ETCD lease ID
    id: i64,
    /// [`CancellationToken`] associated with the lease
    cancel_token: CancellationToken,
}
impl Lease {
    /// Get the lease ID
    pub fn id(&self) -> i64 {
        self.id
    }

    /// Get the primary [`CancellationToken`] associated with the lease.
    /// This token will revoke the lease if canceled.
    pub fn primary_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Get a child [`CancellationToken`] from the lease's [`CancellationToken`].
    /// This child token will be triggered if the lease is revoked, but will not revoke the lease if canceled.
    pub fn child_token(&self) -> CancellationToken {
        self.cancel_token.child_token()
    }

    /// Revoke the lease triggering the [`CancellationToken`].
    pub fn revoke(&self) {
        self.cancel_token.cancel();
    }
}
impl Client {
    /// Start building a [`ClientOptions`] via its derived builder.
    pub fn builder() -> ClientOptionsBuilder {
        ClientOptionsBuilder::default()
    }

    /// Create a new discovery client
    ///
    /// This will establish a connection to the etcd server, create a primary lease,
    /// and spawn a task to keep the lease alive and tie the lifetime of the [`Runtime`]
    /// to the lease.
    ///
    /// If the lease expires, the [`Runtime`] will be shutdown.
    /// If the [`Runtime`] is shutdown, the lease will be revoked.
    pub async fn new(config: ClientOptions, runtime: Runtime) -> Result<Self> {
        // Run the construction on the secondary executor and await its result.
        runtime
            .secondary()
            .spawn(Self::create(config, runtime.clone()))
            .await?
    }

    /// Create a new etcd client and tie the primary [`CancellationToken`] to the primary etcd lease.
    async fn create(config: ClientOptions, runtime: Runtime) -> Result<Self> {
        let token = runtime.primary_token();
        let client =
            etcd_client::Client::connect(config.etcd_url, config.etcd_connect_options).await?;
        let lease_client = client.lease_client();
        // Primary lease uses a fixed 10 (units per etcd lease TTL semantics —
        // presumably seconds; confirm against etcd docs).
        let lease = create_lease(lease_client, 10, token)
            .await
            .context("creating primary lease")?;
        Ok(Client {
            client,
            primary_lease: lease.id,
            runtime,
        })
    }

    /// Get a reference to the underlying [`etcd_client::Client`] instance.
    pub fn etcd_client(&self) -> &etcd_client::Client {
        &self.client
    }

    /// Get the primary lease ID.
    pub fn lease_id(&self) -> i64 {
        self.primary_lease
    }

    /// Primary [`Lease`]
    pub fn primary_lease(&self) -> Lease {
        Lease {
            id: self.primary_lease,
            cancel_token: self.runtime.primary_token(),
        }
    }

    /// Create a [`Lease`] with a given time-to-live (TTL).
    /// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`].
    pub async fn create_lease(&self, ttl: i64) -> Result<Lease> {
        let token = self.runtime.child_token();
        let lease_client = self.client.lease_client();
        self.runtime
            .secondary()
            .spawn(create_lease(lease_client, ttl, token))
            .await?
    }

    /// Atomically create `key` with `value` (optionally bound to `lease_id`)
    /// only if the key does not already exist (version == 0).
    ///
    /// NOTE(review): the transaction response is discarded, so a failed
    /// compare (key already exists) still returns `Ok(())` — verify whether
    /// callers depend on create-or-error semantics.
    pub async fn kv_create(
        &self,
        key: String,
        value: Vec<u8>,
        lease_id: Option<i64>,
    ) -> Result<()> {
        let put_options = lease_id.map(|id| PutOptions::new().with_lease(id));
        // Build the transaction
        let txn = Txn::new()
            .when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
            .and_then(vec![
                TxnOp::put(key.as_str(), value, put_options), // Create the object
            ]);
        // Execute the transaction
        let _ = self.client.kv_client().txn(txn).await?;
        Ok(())
    }

    /// Fetch all key/value pairs under `prefix`.
    pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
        let mut get_response = self
            .client
            .kv_client()
            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
            .await?;
        Ok(get_response.take_kvs())
    }

    /// Fetch the current contents of `prefix` and then watch it for changes.
    ///
    /// The initial key/value pairs are replayed into the channel as
    /// [`WatchEvent::Put`] events, followed by live Put/Delete events. The
    /// watch starts at the revision of the initial read, so no updates
    /// between the read and the watch are lost.
    pub async fn kv_get_and_watch_prefix(
        &self,
        prefix: impl AsRef<str>,
    ) -> Result<PrefixWatcher> {
        let mut kv_client = self.client.kv_client();
        let mut watch_client = self.client.watch_client();
        let mut get_response = kv_client
            .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
            .await?;
        // Revision of the initial read; the watch resumes from here.
        let start_revision = get_response
            .header()
            .ok_or(error!("missing header; unable to get revision"))?
            .revision();
        let (watcher, mut watch_stream) = watch_client
            .watch(
                prefix.as_ref(),
                Some(
                    WatchOptions::new()
                        .with_prefix()
                        .with_start_revision(start_revision),
                ),
            )
            .await?;
        let kvs = get_response.take_kvs();
        let (tx, rx) = mpsc::channel(32);
        // Forwarding task: replay the initial snapshot, then stream watch events.
        self.runtime.secondary().spawn(async move {
            for kv in kvs {
                if tx.send(WatchEvent::Put(kv)).await.is_err() {
                    // receiver is closed
                    break;
                }
            }
            while let Some(Ok(response)) = watch_stream.next().await {
                for event in response.events() {
                    match event.event_type() {
                        etcd_client::EventType::Put => {
                            if let Some(kv) = event.kv() {
                                if tx.send(WatchEvent::Put(kv.clone())).await.is_err() {
                                    // receiver is closed
                                    break;
                                }
                            }
                        }
                        etcd_client::EventType::Delete => {
                            if let Some(kv) = event.kv() {
                                if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
                                    // receiver is closed
                                    break;
                                }
                            }
                        }
                    }
                }
            }
        });
        Ok(PrefixWatcher {
            prefix: prefix.as_ref().to_string(),
            watcher,
            rx,
        })
    }
}
/// Handle returned by [`Client::kv_get_and_watch_prefix`]: the watched
/// prefix, the underlying etcd [`Watcher`], and the channel of events.
#[derive(Dissolve)]
pub struct PrefixWatcher {
    // Prefix being watched.
    prefix: String,
    // Underlying etcd watcher; keep it alive for the watch to stay active.
    watcher: Watcher,
    // Receives the initial snapshot (as Puts) followed by live events.
    rx: mpsc::Receiver<WatchEvent>,
}
/// A single change observed on a watched prefix.
pub enum WatchEvent {
    // Key created or updated.
    Put(KeyValue),
    // Key deleted.
    Delete(KeyValue),
}
/// ETCD client configuration options
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
    // At least one endpoint URL is required.
    #[validate(length(min = 1))]
    etcd_url: Vec<String>,
    // Optional low-level connect options passed through to `etcd_client`.
    #[builder(default)]
    etcd_connect_options: Option<ConnectOptions>,
}
impl Default for ClientOptions {
    /// Default: endpoints from `ETCD_ENDPOINTS` (or localhost) and no
    /// extra connect options.
    fn default() -> Self {
        Self {
            etcd_url: default_servers(),
            etcd_connect_options: None,
        }
    }
}
/// Resolve the default etcd endpoint list.
///
/// Reads `ETCD_ENDPOINTS` as a comma-separated list; entries are trimmed of
/// surrounding whitespace and empty entries (e.g. from a trailing comma or
/// `"a, ,b"`) are dropped. Falls back to `http://localhost:2379` when the
/// variable is unset.
fn default_servers() -> Vec<String> {
    match std::env::var("ETCD_ENDPOINTS") {
        Ok(possible_list_of_urls) => possible_list_of_urls
            .split(',')
            // Tolerate "a, b" style lists: strip whitespace around each entry.
            .map(str::trim)
            // Drop empty segments so "a,,b" or "a,b," cannot yield "" endpoints.
            .filter(|s| !s.is_empty())
            .map(str::to_string)
            .collect(),
        Err(_) => vec!["http://localhost:2379".to_string()],
    }
}
runtime/rust/src/transports/etcd/kv.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
runtime/rust/src/transports/etcd/lease.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use
super
::
*
;
/// Create a [`Lease`] with a given time-to-live (TTL) attached to the [`CancellationToken`].
///
/// Grants the lease, then spawns a background task running [`keep_alive`] on a
/// child of `token`. If the keep-alive task fails, the parent `token` is
/// canceled; the returned [`Lease`] carries the parent token, so canceling it
/// revokes the lease via the keep-alive task.
pub async fn create_lease(
    mut lease_client: LeaseClient,
    ttl: i64,
    token: CancellationToken,
) -> Result<Lease> {
    let lease = lease_client.grant(ttl, None).await?;
    let id = lease.id();
    // Use the TTL granted by the server, which may differ from the request.
    let ttl = lease.ttl();
    // Child token: lets keep_alive observe cancellation without owning the parent.
    let child = token.child_token();
    let clone = token.clone();
    tokio::spawn(async move {
        match keep_alive(lease_client, id, ttl, child).await {
            Ok(_) => log::trace!("keep alive task exited successfully"),
            Err(e) => {
                log::info!("keep alive task failed: {:?}", e);
                // Propagate failure: cancel the parent token so dependents shut down.
                token.cancel();
            }
        }
    });
    Ok(Lease {
        id,
        cancel_token: clone,
    })
}
/// Task to keep leases alive.
///
/// If this task returns an error, the cancellation token will be invoked on the runtime.
/// If the token is canceled externally, the lease is revoked and the task
/// returns `Ok(())`.
pub async fn keep_alive(
    client: LeaseClient,
    lease_id: i64,
    ttl: i64,
    token: CancellationToken,
) -> Result<()> {
    let mut ttl = ttl;
    let mut deadline = create_deadline(ttl)?;
    let mut client = client;
    let (mut heartbeat_sender, mut heartbeat_receiver) = client.keep_alive(lease_id).await?;
    loop {
        // if the deadline is exceeded, then we have failed to issue a heartbeat in time
        // we maybe be permanently disconnected from the etcd server, so we are now officially done
        if deadline < std::time::Instant::now() {
            return Err(error!("failed to issue heartbeat in time"));
        }
        tokio::select! {
            // biased: prefer draining server responses / cancellation over sleeping.
            biased;
            status = heartbeat_receiver.message() => {
                if let Some(resp) = status? {
                    log::trace!(lease_id, "keep alive response received: {:?}", resp);
                    // update ttl and deadline
                    ttl = resp.ttl();
                    deadline = create_deadline(ttl)?;
                    // A TTL of 0 from the server means the lease is gone.
                    if resp.ttl() == 0 {
                        return Err(error!("lease expired or revoked"));
                    }
                }
            }
            _ = token.cancelled() => {
                log::trace!(lease_id, "cancellation token triggered; revoking lease");
                let _ = client.revoke(lease_id).await?;
                return Ok(());
            }
            // Heartbeat at half the TTL to leave headroom before expiry.
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(ttl as u64 / 2)) => {
                log::trace!(lease_id, "sending keep alive");
                // if we get a error issuing the heartbeat, set the ttl to 0
                // this will allow us to poll the response stream once and the cancellation token once, then
                // immediately try to tick the heartbeat
                // this will repeat until either the heartbeat is reestablished or the deadline is exceeded
                if let Err(e) = heartbeat_sender.keep_alive().await {
                    log::warn!(lease_id, "keep alive failed: {:?}", e);
                    ttl = 0;
                }
            }
        }
    }
}
/// Create a deadline for a given time-to-live (TTL).
fn
create_deadline
(
ttl
:
i64
)
->
Result
<
std
::
time
::
Instant
>
{
if
ttl
<=
0
{
return
Err
(
error!
(
"invalid ttl: {}"
,
ttl
));
}
Ok
(
std
::
time
::
Instant
::
now
()
+
std
::
time
::
Duration
::
from_secs
(
ttl
as
u64
))
}
runtime/rust/src/transports/nats.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! NATS transport
//!
//! The following environment variables are used to configure the NATS client:
//!
//! - `NATS_SERVER`: the NATS server address
//!
//! For authentication, the following environment variables are used and prioritized in the following order:
//!
//! - `NATS_AUTH_USERNAME`: the username for authentication
//! - `NATS_AUTH_PASSWORD`: the password for authentication
//! - `NATS_AUTH_TOKEN`: the token for authentication
//! - `NATS_AUTH_NKEY`: the nkey for authentication
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
use
crate
::
Result
;
use
async_nats
::{
client
,
jetstream
,
Subscriber
};
use
derive_builder
::
Builder
;
use
futures
::
TryStreamExt
;
use
std
::
path
::
PathBuf
;
use
validator
::{
Validate
,
ValidationError
};
mod
slug
;
pub
use
slug
::
Slug
;
/// NATS client pairing a core connection with a JetStream context.
#[derive(Clone)]
pub struct Client {
    // Core NATS connection.
    client: client::Client,
    // JetStream context created over the same connection.
    js_ctx: jetstream::Context,
}
impl Client {
    /// Create a NATS [`ClientOptionsBuilder`].
    pub fn builder() -> ClientOptionsBuilder {
        ClientOptionsBuilder::default()
    }

    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
    pub fn client(&self) -> &client::Client {
        &self.client
    }

    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
    pub fn jetstream(&self) -> &jetstream::Context {
        &self.js_ctx
    }

    /// fetch the list of streams
    pub async fn list_streams(&self) -> Result<Vec<String>> {
        let names = self.js_ctx.stream_names();
        let stream_names: Vec<String> = names.try_collect().await?;
        Ok(stream_names)
    }

    /// fetch the list of consumers for a given stream
    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
        let stream = self.js_ctx.get_stream(stream_name).await?;
        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
        Ok(consumers)
    }

    /// Fetch the current state (message counts, etc.) of a named stream.
    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
        let mut stream = self.js_ctx.get_stream(stream_name).await?;
        let info = stream.info().await?;
        Ok(info.state.clone())
    }

    /// Look up a JetStream stream by name.
    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
        let stream = self.js_ctx.get_stream(name).await?;
        Ok(stream)
    }

    /// Broadcast a `$SRV.STATS.<service_name>` request and return the inbox
    /// [`Subscriber`] on which instance replies will arrive. The caller is
    /// responsible for draining the subscription with its own timeout.
    pub async fn service_subscriber(&self, service_name: &str) -> Result<Subscriber> {
        let subject = format!("$SRV.STATS.{}", service_name);
        // Unique reply inbox for this request.
        let reply_subject = format!("_INBOX.{}", nuid::next());
        let subscription = self.client.subscribe(reply_subject.clone()).await?;
        // Publish the request with the reply-to subject
        self.client
            .publish_with_reply(subject, reply_subject, "".into())
            .await?;
        // // Set a timeout to gather responses
        // let mut responses = Vec::new();
        // // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
        // let start = time::Instant::now();
        // while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
        //     tx.send(message.payload);
        //     if start.elapsed() > timeout {
        //         break;
        //     }
        // }
        // Ok(responses)
        Ok(subscription)
    }

    // /// create a new stream
    // async fn get_or_create_work_queue_stream(
    //     &self,
    //     name: &super::Namespace,
    // ) -> Result<jetstream::stream::Stream> {
    //     let stream = self
    //         .js_ctx
    //         .get_or_create_stream(async_nats::jetstream::stream::Config {
    //             name: name.to_string(),
    //             retention: async_nats::jetstream::stream::RetentionPolicy::WorkQueue,
    //             subjects: vec![format!("{name}.>")],
    //             ..Default::default()
    //         })
    //         .await?;
    //     Ok(stream)
    // }

    // // get work queue
    // pub async fn get_or_create_work_queue(
    //     &self,
    //     namespace: &super::Namespace,
    //     queue_name: &Slug,
    // ) -> Result<WorkQueue> {
    //     let stream = self.get_or_create_work_queue_stream(namespace).await?;
    //     let consumer_name = single_name(namespace, queue_name);
    //     let subject_name = subject_name(namespace, queue_name);
    //     let subject_name = format!("{}.*", subject_name);
    //     tracing::trace!(
    //         durable_name = consumer_name,
    //         filter_subject = subject_name,
    //         "get_or_create_work_queue"
    //     );
    //     let consumer = stream
    //         .get_or_create_consumer(
    //             &consumer_name,
    //             jetstream::consumer::pull::Config {
    //                 durable_name: Some(consumer_name.clone()),
    //                 filter_subject: subject_name,
    //                 ack_policy: jetstream::consumer::AckPolicy::Explicit,
    //                 ..Default::default()
    //             },
    //         )
    //         .await?;
    //     Ok(WorkQueue::new(consumer))
    // }

    // pub async fn get_or_create_work_queue_publisher(
    //     &self,
    //     namespace: &super::Namespace,
    //     queue_name: &Slug,
    // ) -> Result<WorkQueuePublisher> {
    //     let _stream = self.get_or_create_work_queue_stream(namespace).await?;
    //     let _subject = subject_name(namespace, queue_name);
    //     Ok(WorkQueuePublisher {
    //         client: self.clone(),
    //         namespace: namespace.clone(),
    //         queue_name: queue_name.clone(),
    //     })
    // }

    // pub async fn list_work_queues(
    //     &self,
    //     namespace: &super::Namespace,
    // ) -> Result<Vec<String>> {
    //     let stream = self.get_stream(namespace.as_ref()).await?;
    //     let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
    //     Ok(consumers)
    // }

    // /// remove a work queue
    // pub async fn remove_work_queue(
    //     &self,
    //     namespace: &super::Namespace,
    //     queue_name: &Slug,
    // ) -> Result<()> {
    //     let stream = self.get_stream(namespace.as_ref()).await?;
    //     let consumer_name = single_name(namespace, queue_name);
    //     let consumers = self.list_consumers(namespace.as_ref()).await?;
    //     if consumers.contains(&consumer_name) {
    //         stream.delete_consumer(&consumer_name).await?;
    //     }
    //     Ok(())
    // }

    // /// publish a message to a subject
    // pub async fn publish(&self, subject: String, msg: Vec<u8>) -> Result<()> {
    //     self.client.publish(subject, msg.into()).await?;
    //     Ok(())
    // }

    // /// subscribe to a subject
    // pub async fn subscribe(
    //     &self,
    //     subject: String,
    // ) -> Result<async_nats::Subscriber> {
    //     let sub = self.client.subscribe(subject).await?;
    //     Ok(sub)
    // }

    // pub async fn enqueue(
    //     &self,
    //     namespace: &super::Namespace,
    //     queue_name: &Slug,
    //     payload: Bytes,
    // ) -> Result<String> {
    //     // let mut headers = HeaderMap::new();
    //     let subject = subject_name(namespace, queue_name);
    //     let request_id = uuid::Uuid::new_v4().to_string();
    //     let subject = format!("{}.{}", subject, request_id);
    //     self.client.publish(subject, payload).await?;
    //     // self.client
    //     //     .publish_with_headers(subject, headers, payload.into())
    //     //     .await?;
    //     Ok(request_id)
    // }

    // pub async fn enqueue_with_id(
    //     &self,
    //     namespace: &super::Namespace,
    //     queue_name: &Slug,
    //     request_id: &str,
    //     payload: Vec<u8>,
    // ) -> Result<()> {
    //     let subject = subject_name(namespace, queue_name);
    //     let subject = format!("{}.{}", subject, request_id);
    //     self.client.publish(subject, payload.into()).await?;
    //     Ok(())
    // }

    // pub async fn get_endpoints(
    //     &self,
    //     service_name: &str,
    //     timeout: Duration,
    // ) -> Result<Vec<Bytes>, anyhow::Error> {
    //     let subject = format!("$SRV.STATS.{}", service_name);
    //     let reply_subject = format!("_INBOX.{}", nuid::next());
    //     let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
    //     // Publish the request with the reply-to subject
    //     self.client
    //         .publish_with_reply(subject, reply_subject, "".into())
    //         .await?;
    //     // Set a timeout to gather responses
    //     let mut responses = Vec::new();
    //     // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
    //     let start = time::Instant::now();
    //     while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
    //         responses.push(message.payload);
    //         if start.elapsed() > timeout {
    //             break;
    //         }
    //     }
    //     Ok(responses)
    // }

    // pub fn frontend_client(&self, request_id: String) -> SpecializedClient {
    //     SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id)
    // }

    // pub fn backend_client(&self, request_id: String) -> SpecializedClient {
    //     SpecializedClient::new(self.client.clone(), ClientKind::Backend, request_id)
    // }
}
/// NATS client options
///
/// This object uses the builder pattern with default values that are evaluates
/// from the environment variables if they are not explicitly set by the builder.
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
    // Server URL; defaults from `NATS_SERVER` and must start with "nats://".
    #[builder(setter(into), default = "default_server()")]
    #[validate(custom(function = "validate_nats_server"))]
    server: String,
    // Authentication method; defaults from the NATS_AUTH_* environment variables.
    #[builder(default)]
    auth: NatsAuth,
}
/// Default NATS server URL: `NATS_SERVER` if set, otherwise localhost.
fn default_server() -> String {
    std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string())
}
/// Validator for the `server` field: accepts only `nats://` URLs.
fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
    if !server.starts_with("nats://") {
        return Err(ValidationError::new("server must start with 'nats://'"));
    }
    Ok(())
}
#[allow(dead_code)]
impl ClientOptions {
    /// Create a new [`ClientOptionsBuilder`]
    pub fn builder() -> ClientOptionsBuilder {
        ClientOptionsBuilder::default()
    }

    /// Validate the config and attempt to connection to the NATS server
    ///
    /// Consumes the options, builds connect options from the configured
    /// [`NatsAuth`] variant, connects, and wraps the connection together
    /// with a fresh JetStream context in a [`Client`].
    pub async fn connect(self) -> Result<Client> {
        // Fail fast on an invalid server URL before dialing.
        self.validate()?;
        let client = match self.auth {
            NatsAuth::UserPass(username, password) => {
                async_nats::ConnectOptions::with_user_and_password(username, password)
            }
            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
            NatsAuth::CredentialsFile(path) => {
                // Reads the credentials file; may fail with I/O errors.
                async_nats::ConnectOptions::with_credentials_file(path).await?
            }
        };
        let client = client.connect(self.server).await?;
        let js_ctx = jetstream::new(client.clone());
        Ok(Client { client, js_ctx })
    }
}
impl Default for ClientOptions {
    /// Default options: server and auth both resolved from the environment.
    fn default() -> Self {
        Self {
            server: default_server(),
            auth: NatsAuth::default(),
        }
    }
}
/// Supported NATS authentication methods.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    // Username and password pair.
    UserPass(String, String),
    // Bearer token.
    Token(String),
    // NKey seed.
    NKey(String),
    // Path to a credentials file.
    CredentialsFile(PathBuf),
}
impl
std
::
fmt
::
Debug
for
NatsAuth
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
match
self
{
NatsAuth
::
UserPass
(
user
,
_
pass
)
=>
{
write!
(
f
,
"UserPass({}, <redacted>)"
,
user
)
}
NatsAuth
::
Token
(
_
token
)
=>
write!
(
f
,
"Token(<redacted>)"
),
NatsAuth
::
NKey
(
_
nkey
)
=>
write!
(
f
,
"NKey(<redacted>)"
),
NatsAuth
::
CredentialsFile
(
path
)
=>
write!
(
f
,
"CredentialsFile({:?})"
,
path
),
}
}
}
impl Default for NatsAuth {
    /// Resolve auth from the environment, in priority order:
    /// username+password pair, token, nkey, credentials file; finally a
    /// hard-coded "user"/"user" development fallback.
    fn default() -> Self {
        let env = |key: &str| std::env::var(key);
        if let (Ok(username), Ok(password)) =
            (env("NATS_AUTH_USERNAME"), env("NATS_AUTH_PASSWORD"))
        {
            return NatsAuth::UserPass(username, password);
        }
        if let Ok(token) = env("NATS_AUTH_TOKEN") {
            return NatsAuth::Token(token);
        }
        if let Ok(nkey) = env("NATS_AUTH_NKEY") {
            return NatsAuth::NKey(nkey);
        }
        if let Ok(path) = env("NATS_AUTH_CREDENTIALS_FILE") {
            return NatsAuth::CredentialsFile(PathBuf::from(path));
        }
        NatsAuth::UserPass("user".to_string(), "user".to_string())
    }
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
figment
::
Jail
;
#[test]
fn
test_client_options_builder
()
{
Jail
::
expect_with
(|
_
jail
|
{
let
opts
=
ClientOptions
::
builder
()
.build
();
assert
!
(
opts
.is_ok
());
Ok
(())
});
Jail
::
expect_with
(|
jail
|
{
jail
.set_env
(
"NATS_SERVER"
,
"nats://localhost:5222"
);
jail
.set_env
(
"NATS_AUTH_USERNAME"
,
"user"
);
jail
.set_env
(
"NATS_AUTH_PASSWORD"
,
"pass"
);
let
opts
=
ClientOptions
::
builder
()
.build
();
assert
!
(
opts
.is_ok
());
let
opts
=
opts
.unwrap
();
assert_eq!
(
opts
.server
,
"nats://localhost:5222"
);
assert_eq!
(
opts
.auth
,
NatsAuth
::
UserPass
(
"user"
.to_string
(),
"pass"
.to_string
())
);
Ok
(())
});
Jail
::
expect_with
(|
jail
|
{
jail
.set_env
(
"NATS_SERVER"
,
"nats://localhost:5222"
);
jail
.set_env
(
"NATS_AUTH_USERNAME"
,
"user"
);
jail
.set_env
(
"NATS_AUTH_PASSWORD"
,
"pass"
);
let
opts
=
ClientOptions
::
builder
()
.server
(
"nats://localhost:6222"
)
.auth
(
NatsAuth
::
Token
(
"token"
.to_string
()))
.build
();
assert
!
(
opts
.is_ok
());
let
opts
=
opts
.unwrap
();
assert_eq!
(
opts
.server
,
"nats://localhost:6222"
);
assert_eq!
(
opts
.auth
,
NatsAuth
::
Token
(
"token"
.to_string
()));
Ok
(())
});
}
// const TEST_STREAM: &str = "test_async_nats_stream";
// #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
// struct Request {
// id: String,
// }
// async fn nats_client() -> Result<Client> {
// Client::builder()
// .server("nats://localhost:4222")
// .username("user")
// .password("user")
// .build()
// .await
// }
// #[tokio::test]
// async fn test_list_streams() {
// let client = match nats_client().await.ok() {
// Some(client) => client,
// None => {
// println!("Failed to create client; skipping nats tests");
// return;
// }
// };
// let streams = client.list_streams().await.expect("failed to list streams");
// for stream in streams {
// let info = client
// .stream_info(&stream)
// .await
// .expect("failed to get stream info");
// assert_eq!(info.messages, 0, "stream {} not empty", stream);
// }
// }
// #[tokio::test]
// async fn test_workq_pull_and_response_stream() {
// let ns: Namespace = TEST_STREAM.try_into().unwrap();
// let _client = match nats_client().await.ok() {
// Some(client) => client,
// None => {
// println!("Failed to create client; skipping nats tests");
// return;
// }
// };
// let client = Client::builder()
// .server("nats://localhost:4222")
// .username("user")
// .password("user")
// .build()
// .await
// .expect("failed to create client");
// let _streams = client.list_streams().await.expect("failed to list streams");
// // assert!(!streams.contains(&TEST_STREAM.to_string()));
// let _stream = client
// .get_or_create_work_queue_stream(&ns)
// .await
// .expect("failed to create stream");
// let model_name: Slug = "foo".try_into().unwrap();
// let request_id = "bar";
// let request = Request {
// id: request_id.to_string(),
// };
// let request_payload = serde_json::to_vec(&request).expect("failed to serialize request");
// // let request = CompletionRequest {
// // prompt: CompletionContext::from_prompt("deep learning is".to_string()).into(),
// // stop_conditions: None,
// // sampling_options: None,
// // };
// // remove work queue if it exists
// client
// .remove_work_queue(&ns, &model_name)
// .await
// .expect("remove work queue does not fail if queue does not exist");
// // get the count of the work queues
// let initial_work_queue_count = client
// .list_work_queues(&ns)
// .await
// .expect("failed to list work queues")
// .len();
// // create work queue
// let workq = client
// .get_or_create_work_queue(&ns, &model_name)
// .await
// .expect("failed to get work queue");
// // new work queue count
// let work_queue_count = client
// .list_work_queues(&ns)
// .await
// .expect("failed to list work queues")
// .len();
// assert_eq!(initial_work_queue_count, work_queue_count - 1);
// client
// .enqueue(&ns, &model_name, request_payload.into())
// .await
// .expect("failed to enqueue completion request");
// let mut messages = workq
// .pull(1, std::time::Duration::from_secs(1))
// .await
// .expect("failed to pull messages from work queue");
// assert_eq!(1, messages.len());
// let msg = messages.pop().expect("no message received");
// msg.ack().await.expect("failed to ack");
// let request: Request =
// serde_json::from_slice(&msg.payload).expect("failed to deserialize message");
// assert_eq!(request.id, request_id);
// // clean up and delete nats work queue and stream
// client
// .remove_work_queue(&ns, &model_name)
// .await
// .expect("failed to remove work queue");
// // client
// // .delete_stream(TEST_STREAM)
// // .await
// // .expect("failed to delete stream");
// }
}
// let frontend_client = client.frontend_client("test".to_string());
// // the represents the frontend response subscription
// let mut frontend_sub = frontend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let backend_client = client.backend_client("test".to_string());
// let mut backend_sub = backend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let msg = messages[0].clone();
// let req = serde_json::from_slice::<CompletionRequest>(&msg.payload)
// .expect("failed to deserialize message");
// msg.ack().await.expect("failed to ack");
// assert_eq!(req.prompt, request.prompt);
// // ping pong message between backend and frontend
// // backend publishes to frontend
// backend_client
// .publish(&MessageKind::Initialize(Prologue {
// formatted_prompt: None,
// input_token_ids: None,
// }))
// .await
// .expect("failed to publish");
// // frontend receives initialize message
// let msg = frontend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match msg {
// MessageKind::Initialize(_) => {}
// _ => panic!("unexpected message"),
// }
// // frontend publishes to backend
// frontend_client
// .publish(&MessageKind::Finalize(Epilogue {}))
// .await
// .expect("failed to publish");
// // backend receives finalize message
// let msg = backend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match &msg {
// MessageKind::Finalize(_) => {}
// _ => panic!("unexpected message"),
// }
// // delete the work queue
// client
// .remove_work_queue(model_name, TEST_STREAM)
// .await
// .expect("failed to remove work queue");
// // new work queue count
// let work_queue_count = client
// .list_work_queues(TEST_STREAM)
// .await
// .expect("failed to list work queues")
// .len();
// // compare against the initial work queue count
// assert_eq!(initial_work_queue_count, work_queue_count);
// }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = time::Instant::now();
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// async fn connect(config: Arc<Config>) -> Result<NatsClient> {
// let client = ClientOptions::builder()
// .server(config.nats_address.clone())
// .build()
// .await
// .context("Creating NATS Client")?;
// Ok(client)
// }
// async fn create_service(
// nats: NatsClient,
// config: Arc<Config>,
// observer: ServiceObserver,
// ) -> Result<NatsService> {
// let service = nats
// .client()
// .service_builder()
// .description(config.service_description.as_str())
// .stats_handler(move |_name, _stats| {
// let stats = InstanceStats {
// stage: observer.stage(),
// };
// serde_json::to_value(&stats).unwrap()
// })
// .start(
// config.service_name.as_str(),
// config.service_version.as_str(),
// )
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
// Ok(service)
// }
// async fn create_endpoint(
// endpoint_name: impl Into<String>,
// service: &NatsService,
// ) -> Result<Endpoint> {
// let info = service.info().await;
// let group_name = format!("{}-{}", info.name, info.id);
// let group = service.group(group_name);
// let endpoint = group
// .endpoint(endpoint_name.into())
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
// Ok(endpoint)
// }
// async fn shutdown_endpoint_handler(
// controller: ServiceController,
// endpoint: Endpoint,
// ) -> Result<()> {
// let mut endpoint = endpoint;
// // note: this is a child cancellation token, canceling it will not cancel the parent
// // but the parent will cancel the child -- we only use this to observe if another
// // controller has cancelled the service
// let cancellation_token = controller.cancel_token();
// loop {
// let req = tokio::select! {
// _ = cancellation_token.cancelled() => {
// // log::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
// return Ok(());
// }
// // await on service request
// req = endpoint.next() => {
// req
// }
// };
// if let Some(req) = req {
// let response = "DONE".to_string();
// if let Err(e) = req.respond(Ok(response.into())).await {
// log::warn!("Failed to respond to the shutdown request: {:?}", e);
// }
// controller.set_stage(ServiceStage::ShuttingDown);
// }
// }
// }
// #[derive(Debug, Clone, Builder)]
// pub struct Config {
// /// The NATS server address
// #[builder(default = "String::from(\"nats://localhost:4222\")")]
// pub nats_address: String,
// #[builder(setter(into), default = "String::from(SERVICE_NAME)")]
// pub service_name: String,
// #[builder(setter(into), default = "String::from(SERVICE_VERSION)")]
// pub service_version: String,
// #[builder(setter(into), default = "String::from(SERVICE_DESCRIPTION)")]
// pub service_description: String,
// }
// impl Config {
// pub fn new() -> Result<Config> {
// Ok(ConfigBuilder::default().build()?)
// }
// /// Create a new [`ConfigBuilder`]
// pub fn builder() -> ConfigBuilder {
// ConfigBuilder::default()
// }
// }
// // todo: move to icp - transports
// #[derive(Clone, Debug)]
// pub struct NatsClient {
// client: Client,
// js_ctx: jetstream::Context,
// }
// impl NatsClient {
// pub fn client(&self) -> &Client {
// &self.client
// }
// pub fn jetstream(&self) -> &jetstream::Context {
// &self.js_ctx
// }
// pub fn service_builder(&self) -> NatsServiceBuilder {
// self.client.service_builder()
// }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = tokio::time::Instant::now();
// while let Ok(Some(message)) = tokio::time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct ServiceInfo {
// pub name: String,
// pub id: String,
// pub version: String,
// pub started: String,
// pub endpoints: Vec<EndpointInfo>,
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct EndpointInfo {
// pub name: String,
// pub subject: String,
// pub data: serde_json::Value,
// }
// impl EndpointInfo {
// pub fn get<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
// serde_json::from_value(self.data.clone()).map_err(Into::into)
// }
// }
// #[derive(Clone, Debug, Builder)]
// #[builder(build_fn(private, name = "build_internal"))]
// pub struct ClientOptions {
// #[builder(setter(into))]
// server: String,
// #[builder(setter(into, strip_option), default)]
// username: Option<String>,
// #[builder(setter(into, strip_option), default)]
// password: Option<String>,
// }
// #[allow(dead_code)]
// impl ClientOptions {
// pub fn builder() -> ClientOptionsBuilder {
// ClientOptionsBuilder::default()
// }
// }
// impl ClientOptionsBuilder {
// pub async fn build(&self) -> Result<NatsClient> {
// let opts = self.build_internal()?;
// // Create an unauthenticated connection to NATS.
// let client = async_nats::ConnectOptions::new();
// let client = if let (Some(username), Some(password)) = (opts.username, opts.password) {
// client.user_and_password(username, password)
// } else {
// client
// };
// let client = client.connect(&opts.server).await?;
// let js_ctx = jetstream::new(client.clone());
// Ok(NatsClient { client, js_ctx })
// }
// }
runtime/rust/src/transports/nats/slug.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use
serde
::
de
::{
self
,
Deserializer
,
Visitor
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
fmt
;
const
REPLACEMENT_CHAR
:
char
=
'_'
;
/// URL and NATS friendly string.
/// Only a-z, 0-9, - and _.
// Newtype over String; the inner value is only constructed through the
// validating/sanitizing paths below.
#[derive(Serialize, Clone, Debug, Eq, PartialEq)]
pub struct Slug(String);
impl Slug {
    /// Wrap an already-sanitized string, stripping any leading
    /// replacement characters first.
    fn new(s: String) -> Slug {
        let trimmed = s.trim_start_matches(REPLACEMENT_CHAR);
        Slug(trimmed.to_string())
    }

    /// Create [`Slug`] from a string.
    pub fn from_string(s: impl AsRef<str>) -> Slug {
        Slug::slugify_unique(s.as_ref())
    }

    /// Like slugify but also add a four byte hash on the end, in case two
    /// different strings slug to the same thing.
    fn slugify_unique(s: &str) -> Slug {
        let sanitize = |c: char| {
            if c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_' {
                c
            } else {
                REPLACEMENT_CHAR
            }
        };
        let body: String = s.to_lowercase().chars().map(sanitize).collect();
        // blake3 hex digest of the *original* string; keep the last 8 hex
        // chars (4 bytes) as a disambiguating suffix.
        let digest = blake3::hash(s.as_bytes()).to_string();
        let suffix = &digest[digest.len() - 8..];
        Slug::new(format!("{body}-{suffix}"))
    }
}
impl fmt::Display for Slug {
    /// Render the slug as its inner string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
#[derive(Debug)]
pub
struct
InvalidSlugError
(
char
);
impl
fmt
::
Display
for
InvalidSlugError
{
fn
fmt
(
&
self
,
f
:
&
mut
fmt
::
Formatter
)
->
fmt
::
Result
{
write!
(
f
,
"Invalid char '{}'. String can only contain a-z, 0-9, - and _."
,
self
.0
)
}
}
// Marker impl so InvalidSlugError can be boxed / used as `dyn Error`.
impl std::error::Error for InvalidSlugError {}
impl
TryFrom
<&
str
>
for
Slug
{
type
Error
=
InvalidSlugError
;
fn
try_from
(
s
:
&
str
)
->
Result
<
Self
,
Self
::
Error
>
{
s
.to_string
()
.try_into
()
}
}
impl
TryFrom
<
String
>
for
Slug
{
type
Error
=
InvalidSlugError
;
fn
try_from
(
s
:
String
)
->
Result
<
Self
,
Self
::
Error
>
{
let
is_invalid
=
|
c
:
&
char
|
!
c
.is_ascii_lowercase
()
&&
!
c
.is_ascii_digit
()
&&
*
c
!=
'-'
&&
*
c
!=
'_'
;
match
s
.chars
()
.find
(
is_invalid
)
{
None
=>
Ok
(
Slug
(
s
)),
Some
(
c
)
=>
Err
(
InvalidSlugError
(
c
)),
}
}
}
impl<'de> Deserialize<'de> for Slug {
    /// Deserialize a string, rejecting values that are not valid slugs.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor that funnels incoming strings through `Slug::try_from`,
        // converting validation failures into serde errors.
        struct SlugVisitor;
        impl Visitor<'_> for SlugVisitor {
            type Value = Slug;
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter
                    .write_str("a valid slug string containing only characters a-z, 0-9, - and _.")
            }
            // Borrowed-string path.
            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Slug::try_from(v).map_err(de::Error::custom)
            }
            // Owned-string path; validates the borrowed view.
            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Slug::try_from(v.as_ref()).map_err(de::Error::custom)
            }
        }
        deserializer.deserialize_string(SlugVisitor)
    }
}
impl AsRef<str> for Slug {
    /// Borrow the slug's inner string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl PartialEq<str> for Slug {
    /// Compare a slug directly against a string slice.
    fn eq(&self, other: &str) -> bool {
        self.0.as_str() == other
    }
}
runtime/rust/src/transports/tcp.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
/// Re-export the TCP client/server implementations from the pipeline
/// network layer so they are reachable as `transports::tcp::{client, server}`.
pub use crate::pipeline::network::tcp::{client, server};
runtime/rust/src/worker.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
//! The [Worker] class is a convenience wrapper around the construction of the [Runtime]
//! and execution of the users application.
//!
//! In the future, the [Worker] should probably be moved to a procedural macro similar
//! to the `#[tokio::main]` attribute, where we might annotate an async main function with
//! #[triton::main] or similar.
//!
//! The [Worker::execute] method is designed to be called once from main and will block
//! the calling thread until the application completes or is canceled. The method initializes
//! the signal handler used to trap `SIGINT` and `SIGTERM` signals and trigger a graceful shutdown.
//!
//! On termination, the user application is given a graceful shutdown period controlled by
//! the [TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT] environment variable. If the application does not
//! shutdown in time, the worker will terminate the application with an exit code of 911.
//!
//! The default values of `TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT` differ between the development
//! and release builds. In development, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG] and
//! in release, the default is [DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE].
use
super
::{
error
,
log
,
CancellationToken
,
Result
,
Runtime
,
RuntimeConfig
};
use
futures
::
Future
;
use
once_cell
::
sync
::
OnceCell
;
use
std
::{
sync
::
Mutex
,
time
::
Duration
};
use
tokio
::{
signal
,
task
::
JoinHandle
};
// The singleton tokio runtime backing every Worker in this process; set once
// in `Worker::from_config`.
static RT: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
// Holds the root application task handle; set and then taken exactly once by
// `Worker::execute` (the Option lets the handle be moved out of the Mutex).
static INIT: OnceCell<Mutex<Option<tokio::task::JoinHandle<Result<()>>>>> = OnceCell::new();
const
SHUTDOWN_MESSAGE
:
&
str
=
"Application received shutdown signal; attempting to gracefully shutdown"
;
const
SHUTDOWN_TIMEOUT_MESSAGE
:
&
str
=
"Use TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT to control the graceful shutdown timeout"
;
/// Environment variable to control the graceful shutdown timeout
pub
const
TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT
:
&
str
=
"TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT"
;
/// Default graceful shutdown timeout in seconds in debug mode
pub
const
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG
:
u64
=
5
;
/// Default graceful shutdown timeout in seconds in release mode
pub
const
DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE
:
u64
=
30
;
/// Convenience wrapper that owns the process-wide [`Runtime`] and drives the
/// user application to completion with graceful-shutdown handling.
pub struct Worker {
    // Handle to the shared runtime created in `from_config`.
    runtime: Runtime,
}
impl Worker {
    /// Create a new [`Worker`] instance from [`RuntimeConfig`] settings which is sourced from the environment
    pub fn from_settings() -> Result<Worker> {
        let config = RuntimeConfig::from_settings()?;
        Worker::from_config(config)
    }

    /// Create a new [`Worker`] instance from a provided [`RuntimeConfig`]
    ///
    /// Only one `Worker` may ever be created per process: the backing tokio
    /// runtime is stored in the process-wide `RT` OnceCell.
    pub fn from_config(config: RuntimeConfig) -> Result<Worker> {
        // if the runtime is already initialized, return an error
        if RT.get().is_some() {
            return Err(error!("Worker already initialized"));
        }
        // create a new runtime and insert it into the OnceCell
        // there is still a potential race-condition here: two threads could have
        // passed the first check, but only one will succeed in inserting the runtime
        let rt = RT.try_insert(config.create_runtime()?).map_err(|_| {
            error!("Failed to create worker; Only a single Worker should ever be created")
        })?;
        let runtime = Runtime::from_handle(rt.handle().clone())?;
        Ok(Worker { runtime })
    }

    /// Access the process-wide tokio runtime; errors if no Worker was created.
    pub fn tokio_runtime(&self) -> Result<&'static tokio::runtime::Runtime> {
        RT.get().ok_or_else(|| error!("Worker not initialized"))
    }

    /// Borrow the [`Runtime`] handle owned by this worker.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

    /// Executes the provided application/closure on the [`Runtime`].
    /// This is designed to be called once from main and will block the calling thread until the application completes.
    ///
    /// Shutdown behavior: once a signal (or cancellation) is observed, the
    /// application gets `TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT` seconds to
    /// finish; otherwise the process exits with code 911.
    pub fn execute<F, Fut>(self, f: F) -> Result<()>
    where
        F: FnOnce(Runtime) -> Fut + Send + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        let runtime = self.runtime;
        let primary = runtime.primary();
        let secondary = runtime.secondary.clone();
        // Graceful-shutdown window, from the environment or a build-dependent default.
        let timeout = std::env::var(TRITON_WORKER_GRACEFUL_SHUTDOWN_TIMEOUT)
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or({
                if cfg!(debug_assertions) {
                    DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_DEBUG
                } else {
                    DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_RELEASE
                }
            });
        INIT.set(Mutex::new(Some(secondary.spawn(async move {
            // start signal handler
            tokio::spawn(signal_handler(runtime.cancellation_token.clone()));
            let cancel_token = runtime.child_token();
            // The oneshot sender's `closed()` doubles as an "application
            // finished" notification: `app_rx` is dropped when `f` returns.
            let (mut app_tx, app_rx) = tokio::sync::oneshot::channel::<()>();
            // spawn a task to run the application
            let task: JoinHandle<Result<()>> = primary.spawn(async move {
                let _rx = app_rx;
                f(runtime).await
            });
            // Wait for either a shutdown request or natural completion.
            tokio::select! {
                _ = cancel_token.cancelled() => {
                    eprintln!("{}", SHUTDOWN_MESSAGE);
                    eprintln!("{} {} seconds", SHUTDOWN_TIMEOUT_MESSAGE, timeout);
                }
                _ = app_tx.closed() => {}
            };
            // Give the application `timeout` seconds to wind down, then hard-exit.
            let result = tokio::select! {
                result = task => { result }
                _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
                    eprintln!("Application did not shutdown in time; terminating");
                    std::process::exit(911);
                }
            }?;
            match &result {
                Ok(_) => {
                    log::info!("Application shutdown successfully");
                }
                Err(e) => {
                    log::error!("Application shutdown with error: {:?}", e);
                }
            }
            result
        }))))
        .map_err(|e| error!("Failed to spawn application task: {:?}", e))?;
        // Take the task handle back out of INIT and block until it completes.
        let task = INIT
            .get()
            .expect("Application task not initialized")
            .lock()
            .unwrap()
            .take()
            .expect("Application initialized; but another thread is awaiting it; Worker.execute() can only be called once");
        secondary.block_on(task)?
    }
}
/// Catch signals and trigger a shutdown
///
/// Resolves (and cancels `cancel_token`) on the first of: Ctrl+C (SIGINT),
/// SIGTERM (unix), or an external cancellation of the token.
async fn signal_handler(cancel_token: CancellationToken) -> Result<()> {
    // Future resolving on Ctrl+C.
    let ctrl_c = async {
        signal::ctrl_c().await?;
        anyhow::Ok(())
    };
    // Future resolving on SIGTERM. Registration happens lazily inside the
    // future; a registration failure ends this branch via `?` (the select arm
    // ignores the Err and would fire regardless).
    let sigterm = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())?
            .recv()
            .await;
        anyhow::Ok(())
    };
    tokio::select! {
        _ = ctrl_c => {
            tracing::info!("Ctrl+C received, starting graceful shutdown");
        },
        _ = sigterm => {
            tracing::info!("SIGTERM received, starting graceful shutdown");
        },
        _ = cancel_token.cancelled() => {
            tracing::info!("CancellationToken triggered; shutting down");
        },
    }
    // trigger a shutdown
    cancel_token.cancel();
    Ok(())
}
runtime/rust/tests/common/engines.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
#![allow(dead_code)]
use
std
::{
future
::
Future
,
pin
::
Pin
,
sync
::
Arc
};
use
async_trait
::
async_trait
;
use
futures
::
Stream
;
use
tokio
::
sync
::
mpsc
;
use
triton_distributed
::
engine
::{
AsyncEngine
,
AsyncEngineContext
,
AsyncEngineContextProvider
,
AsyncEngineStream
,
Data
as
DataType
,
Engine
,
EngineStream
,
};
use
triton_distributed
::
pipeline
::{
context
::{
Context
,
StreamContext
},
Error
,
ManyOut
,
SingleIn
,
};
/// Type-erased async closure: maps a `T` to a pinned, sendable future
/// yielding a `U`.
pub type AsyncFn<T, U> = dyn Fn(T) -> Pin<Box<dyn Future<Output = U> + Send>> + Send + Sync;
#[derive(Clone)]
// Define a struct that holds an async closure
pub struct AsyncProcessor<T, U> {
    // Shared, type-erased async function; the Arc is what makes the
    // processor cheaply cloneable.
    func: Arc<AsyncFn<T, U>>,
}
impl
<
T
,
U
>
AsyncProcessor
<
T
,
U
>
where
T
:
Send
+
'static
,
U
:
Send
+
'static
,
{
// Define a `new` method that captures the already pinned async block
pub
fn
new
<
F
,
Fut
>
(
f
:
F
)
->
Self
where
F
:
Fn
(
T
)
->
Fut
+
Send
+
Sync
+
'static
,
Fut
:
Future
<
Output
=
U
>
+
Send
+
'static
,
{
// Wrap the closure in Arc and Box it for internal management
AsyncProcessor
{
func
:
Arc
::
new
(
move
|
input
:
T
|
Box
::
pin
(
f
(
input
))),
}
}
// Method to execute the captured async function
pub
async
fn
process
(
&
self
,
input
:
T
)
->
U
{
(
self
.func
)(
input
)
.await
}
}
/// Handle given to a generator for pushing responses into the output stream
/// and observing stop requests.
#[derive(Debug, Clone)]
pub struct ResponseSource<T: Send + Sync + 'static> {
    // Channel into which responses are emitted.
    tx: mpsc::Sender<T>,
    // Stream context used to observe stop/cancel requests.
    ctx: StreamContext,
}
impl<T: Send + Sync + 'static> ResponseSource<T> {
    /// Bundle a response channel with its stream context.
    fn new(tx: mpsc::Sender<T>, ctx: StreamContext) -> Self {
        Self { tx, ctx }
    }

    /// Emit a response to the stream
    pub async fn emit(&self, data: T) -> Result<(), ()> {
        match self.tx.send(data).await {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }

    /// Check if a stop has been requested
    pub fn stop_requested(&self) -> bool {
        self.ctx.is_stopped()
    }

    /// Yield control until a stop is requested
    /// This is useful in a tokio::select! block
    pub async fn stopped(&self) {
        self.ctx.stopped().await;
    }
}
/// Generator signature: consumes a request paired with a [`ResponseSource`]
/// for the response type; results flow through the source, not a return value.
pub type AsyncGenerator<Req, Resp> = AsyncProcessor<(Req, ResponseSource<Resp>), ()>;
/// Adapts a tokio mpsc receiver into a response [`Stream`] that also carries
/// an [`AsyncEngineContext`].
pub struct ReceiverStream<Resp: DataType> {
    // Source of response items.
    receiver: tokio::sync::mpsc::Receiver<Resp>,
    // Context associated with the stream, handed out to consumers.
    context: Arc<dyn AsyncEngineContext>,
}
impl
<
Resp
:
DataType
>
ReceiverStream
<
Resp
>
{
pub
fn
new
(
receiver
:
tokio
::
sync
::
mpsc
::
Receiver
<
Resp
>
,
context
:
Arc
<
dyn
AsyncEngineContext
>
,
)
->
Self
{
Self
{
receiver
,
context
}
}
}
impl<Resp: DataType> Stream for ReceiverStream<Resp> {
    type Item = Resp;
    /// Delegate polling to the underlying mpsc receiver; the stream ends when
    /// all senders are dropped.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // if self.context.stop_issued() {
        //     return std::task::Poll::Ready(None);
        // }
        // Pinning the receiver to safely call poll_recv
        Pin::new(&mut self.receiver).poll_recv(cx)
    }
}
impl<Resp: DataType> std::fmt::Debug for ReceiverStream<Resp> {
    /// Show only the context; the receiver itself has no useful Debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ReceiverStream");
        dbg.field("context", &self.context);
        dbg.finish()
    }
}
// Marker impl: a ReceiverStream is usable as an engine response stream.
impl<Resp: DataType> AsyncEngineStream<Resp> for ReceiverStream<Resp> {}
impl
<
Resp
:
DataType
>
AsyncEngineContextProvider
for
ReceiverStream
<
Resp
>
{
fn
context
(
&
self
)
->
Arc
<
dyn
AsyncEngineContext
>
{
self
.context
.clone
()
}
}
/// Engine implementation backed by a user-supplied async generator closure.
pub struct LlmdbaEngine<Req: DataType, Resp: DataType> {
    // Generator invoked once per request; shared so each request's spawned
    // task can hold its own clone.
    lambda: Arc<AsyncGenerator<Req, Resp>>,
}
impl
<
Req
:
DataType
,
Resp
:
DataType
>
LlmdbaEngine
<
Req
,
Resp
>
{
fn
new
(
lambda
:
AsyncGenerator
<
Req
,
Resp
>
)
->
Self
{
LlmdbaEngine
{
lambda
:
Arc
::
new
(
lambda
),
}
}
pub
fn
from_generator
(
generator
:
AsyncGenerator
<
Req
,
Resp
>
,
)
->
Engine
<
SingleIn
<
Req
>
,
ManyOut
<
Resp
>
,
Error
>
{
Arc
::
new
(
LlmdbaEngine
::
new
(
generator
))
}
}
#[async_trait]
impl<Req: DataType, Resp: DataType> AsyncEngine<SingleIn<Req>, ManyOut<Resp>, Error>
    for LlmdbaEngine<Req, Resp>
{
    /// Launch the generator for `request` on a background task and return a
    /// stream of the responses it emits.
    async fn generate(&self, request: Context<Req>) -> Result<EngineStream<Resp>, Error> {
        // Channel carrying responses from the generator task to the stream.
        let (tx, rx) = mpsc::channel::<Resp>(1);
        // Split the request payload from its context.
        let (req, ctx) = request.transfer(());
        let ctx: StreamContext = ctx.into();
        let s = ResponseSource::new(tx, ctx.clone());
        let lambda = self.lambda.clone();
        // Drive the generator in the background; the returned stream yields
        // whatever it emits and ends when the sender side is dropped.
        let _handle = tokio::spawn(async move { lambda.process((req, s)).await });
        let ctx = Arc::new(ctx);
        let stream = ReceiverStream::<Resp>::new(rx, ctx);
        let stream = Box::pin(stream);
        Ok(stream)
    }
}
#[cfg(test)]
mod
tests
{
use
futures
::
StreamExt
;
use
super
::
*
;
/// Verify that `AsyncProcessor` runs the captured async closure and returns
/// its result. The original test only printed the results; it now asserts
/// the exact expected output so a regression actually fails the test.
#[tokio::test]
async fn test_async_processor() {
    let processor = AsyncProcessor::new(move |x: i32| {
        async move {
            // Simulate some async work
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            format!("Processed value: {}", x)
        }
    });
    // Use the processor to run the async closure and check each result.
    let result = processor.process(42).await;
    assert_eq!(result, "Processed value: 42");
    let result2 = processor.process(100).await;
    assert_eq!(result2, "Processed value: 100");
}
#[tokio::test]
async
fn
test_generator
()
{
let
generator
=
AsyncGenerator
::
<
String
,
String
>
::
new
(|(
req
,
stream
)|
async
move
{
let
chars
=
req
.chars
()
.collect
::
<
Vec
<
char
>>
();
for
c
in
chars
{
match
stream
.emit
(
c
.to_string
())
.await
{
Ok
(
_
)
=>
{}
Err
(
_
)
=>
break
,
}
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
100
))
.await
;
}
});
let
engine
=
LlmdbaEngine
::
new
(
generator
);
let
mut
stream
=
engine
.generate
(
"test"
.to_string
()
.into
())
.await
.unwrap
();
let
mut
counter
=
0
;
while
let
Some
(
_
output
)
=
stream
.next
()
.await
{
counter
+=
1
;
}
assert_eq!
(
counter
,
4
);
}
}
runtime/rust/tests/common/mock.rs
0 → 100644
View file @
5ed8c1c0
/*
* Copyright 2024-2025 NVIDIA CORPORATION & AFFILIATES
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
OnceLock
};
use
async_trait
::
async_trait
;
use
futures
::
StreamExt
;
use
serde
::{
Deserialize
,
Serialize
};
use
tokio
::
sync
::
mpsc
;
use
triton_distributed
::
engine
::{
AsyncEngine
,
AsyncEngineContext
,
Data
,
ResponseStream
};
use
triton_distributed
::
pipeline
::{
context
::{
Context
,
StreamContext
},
Error
,
ManyOut
,
PipelineError
,
PipelineIO
,
SegmentSource
,
SingleIn
,
};
#[allow(dead_code)]
#[derive(Debug,
Clone)]
pub
enum
LatencyModel
{
NoDelay
,
ConstantDelayInNanos
(
u64
),
NormalDistributionInNanos
(
u64
,
u64
),
}
/// Configuration for the latency injected by the mock network transport.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct MockNetworkOptions {
    // Latency applied on the request path.
    request_latency: LatencyModel,
    // Latency applied on the response path.
    response_latency: LatencyModel,
}
impl
Default
for
MockNetworkOptions
{
fn
default
()
->
Self
{
Self
{
request_latency
:
LatencyModel
::
NoDelay
,
response_latency
:
LatencyModel
::
NoDelay
,
}
}
}
/// A request submitted over the mock control plane.
#[derive(Debug, Clone)]
struct ControlPlaneRequest {
    // Unique request identifier.
    id: String,
    // Serialized request payload.
    request: Vec<u8>,
    // convert this into an interface where it describes the worker address
    // and how to communicate with the worker
    resp_tx: mpsc::Sender<DataPlaneMessage>,
}
/// Events flowing from egress to ingress on the mock control plane.
enum MockNetworkControlEvents {
    /// A new request to dispatch to the attached segment.
    ControlPlaneRequest(ControlPlaneRequest),
    /// Cancel the in-flight request with the given id.
    Cancel(String),
}
/// Control metadata carried in the headers of a [`DataPlaneMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MockNetworkDataPlaneHeaders {
    /// First message on the data plane; acknowledges (or rejects) the request.
    Handshake(Handshake),
    /// An error raised while producing the response stream.
    Error(String),
    // tells the subscriber that the stream has ended
    // not all transports will be sender side closable, therefore,
    // we need a way to signal the end of the stream
    //
    // note: for transports like nats where the subscriber could
    // be left dangling, we will also want to have a keep alive
    // and a timeout mechanism
    Sentinel,
    // heart beat / keep-alive signal to maintain the connection
    HeartBeat,
}
/// Outcome of a remote request, carried inside the [`Handshake`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum Status {
    /// The remote segment accepted the request.
    Ok,
    /// The remote segment failed to process the request; payload is the error text.
    Error(String),
}
// for transports that support headers, we will use headers for events and the body for the bytes
// for transports like tcp, we may send them as two separate messages on the same socket or as a single
// compound message like the [`DataEnvelope`] object below
/// First message sent back on the data plane; tells the requester whether the
/// remote segment accepted the request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Handshake {
    // id of the request this handshake acknowledges
    request_id: String,
    // identity of the worker that took the request, when known
    worker_id: Option<String>,
    // whether the remote segment could process the request
    status: Status,
}
/// A single frame on the mock data plane: optional control headers plus an opaque payload.
struct DataPlaneMessage {
    // control metadata (handshake, sentinel, heartbeat, ...); None for plain payload frames
    pub headers: Option<MockNetworkDataPlaneHeaders>,
    // serialized response bytes; empty for pure control frames (e.g. the handshake)
    pub body: Vec<u8>,
}
/// This is an example transport that will inject latency into the response stream.
/// This mimics a network transport that has a delay in the response.
pub struct MockNetworkTransport<T: PipelineIO, U: PipelineIO> {
    // zero-sized markers pinning the request/response types of the transport
    req: std::marker::PhantomData<T>,
    resp: std::marker::PhantomData<U>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkTransport<Req, Resp> {
    /// Builds a connected egress/ingress pair that communicates over an in-process
    /// control-plane channel (capacity 8).
    ///
    /// The egress is the request-sender/response-receiver; the ingress is the
    /// request-receiver/response-sender.
    pub fn new_egress_ingress(
        options: MockNetworkOptions,
    ) -> (
        Arc<MockNetworkEgress<Req, Resp>>,
        MockNetworkIngress<Req, Resp>,
    ) {
        let (ctrl_tx, ctrl_rx) = mpsc::channel::<MockNetworkControlEvents>(8);
        // construct the egress/request-sender/response-receiver
        // (fix: clippy::redundant_clone — the sender is moved directly instead of cloned
        // and then dropped, so the channel's sender count reflects actual holders)
        let egress = Arc::new(MockNetworkEgress::<Req, Resp>::new(options.clone(), ctrl_tx));
        // construct the ingress/request-receiver/response-sender
        // (last use of `options`, so it is moved rather than cloned)
        let ingress = MockNetworkIngress::<Req, Resp>::new(options, ctrl_rx);
        (egress, ingress)
    }
}
/// Client half of the mock transport: serializes requests onto the control plane and
/// exposes the data-plane frames as a typed response stream.
#[allow(dead_code)]
pub struct MockNetworkEgress<Req: PipelineIO, Resp: PipelineIO> {
    // latency options; not read in this file — TODO confirm latency injection is still planned
    options: MockNetworkOptions,
    // sender side of the shared control-plane channel
    ctrl_tx: mpsc::Sender<MockNetworkControlEvents>,
    // zero-sized markers pinning the request/response types
    req: std::marker::PhantomData<Req>,
    resp: std::marker::PhantomData<Resp>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkEgress<Req, Resp> {
    /// Creates an egress bound to the given control-plane sender.
    fn new(options: MockNetworkOptions, ctrl_tx: mpsc::Sender<MockNetworkControlEvents>) -> Self {
        Self {
            options,
            ctrl_tx,
            req: std::marker::PhantomData,
            resp: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
    for MockNetworkEgress<SingleIn<T>, ManyOut<U>>
where
    T: Data + Serialize,
    U: for<'de> Deserialize<'de> + Data,
{
    /// Sends one serialized request over the mock control plane and returns the
    /// deserialized response stream received over the mock data plane.
    ///
    /// Protocol: the first data-plane frame must be a [`Handshake`] with an empty body;
    /// subsequent frames carry serialized `U` payloads until a `Sentinel` header ends
    /// the stream. A cancellation monitor task forwards context stop signals to the
    /// control plane as `Cancel` events.
    ///
    /// # Errors
    /// Returns [`PipelineError::ControlPlaneRequestError`] if the request cannot be
    /// delivered, or if the first data-plane message is not a successful handshake.
    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
        let id = request.id().to_string();
        // serialize the request
        let request = request.try_map(|req| serde_json::to_vec(&req))?;
        // transfer the request context to a stream context
        let (data, context) = request.transfer(());
        let context = Arc::new(StreamContext::from(context));
        // subscribe to the response stream
        // but in this case, we are doing a mock, so we are going to be more explicit
        // since we are transferring data over a channel instead of the network, creating the
        // channel is the same as subscribing to the response stream
        let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16);
        let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx);
        // prepare the stateful objects that will be used to monitor the response stream;
        // finish_rx is a oneshot channel used to signal the natural termination of the stream
        let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>();
        let stream_monitor = ResponseMonitor {
            ctx: context.clone(),
            finish_rx: finished_rx,
        };
        // create the control plane request
        // when this is issued, control is handed off to the control plane and the downstream segment
        // sometimes we might include the local server address and port for the response to find its way home
        // todo(design) this will be part of the generalization effort for multiple transport types
        let request = ControlPlaneRequest {
            id,
            request: data,
            resp_tx: data_tx,
        };
        // send the request to the control plane
        self.ctrl_tx
            .send(MockNetworkControlEvents::ControlPlaneRequest(request))
            .await
            .map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?;
        // the first message from the remote publisher on the data plane needs to be a handshake
        // message; the handshake indicates what stream the data belongs to and whether the
        // remote segment was able to process the request.
        //
        // note: in the case of the mock transport, the handshaking of the request id is not
        // strictly needed because the channel is specific to the request. this is similar to
        // other transports like nats where we subscribe to a response stream on a subject
        // unique to the stream.
        match byte_stream.next().await {
            Some(DataPlaneMessage { headers, body }) => {
                // handshake frames carry everything in the headers; a payload here is a
                // protocol violation
                if !body.is_empty() {
                    Err(PipelineError::ControlPlaneRequestError(
                        "Expected an empty body for the handshake message".to_string(),
                    ))?;
                }
                match headers {
                    Some(header) => {
                        match header {
                            MockNetworkDataPlaneHeaders::Handshake(handshake) => {
                                match handshake.status {
                                    Status::Ok => {}
                                    Status::Error(e) => {
                                        // todo(metrics): increment metric counter for failed handshakes
                                        Err(PipelineError::ControlPlaneRequestError(format!(
                                            "remote segment was unable to process request: {}",
                                            e
                                        )))?;
                                    }
                                }
                            }
                            _ => {
                                Err(PipelineError::ControlPlaneRequestError(format!(
                                    "Expected a handshake message; got: {:?}",
                                    header
                                )))?;
                            }
                        }
                    }
                    _ => {
                        Err(PipelineError::ControlPlaneRequestError(
                            "Failed to receive properly formatted handshake on data plane"
                                .to_string(),
                        ))?;
                    }
                }
            }
            None => {
                // todo(metrics): increment metric counter for failed requests
                Err(PipelineError::ControlPlaneRequestError(
                    "Failed data plane connection closed before receiving handshake".to_string(),
                ))?;
            }
        }
        // filter control frames out of the data plane and decode payload frames into `U`
        let decoded = byte_stream
            // .inspect(|_item| {
            //     // todo(metrics) increment the metrics counter by the number of bytes
            // })
            .scan(Some(stream_monitor), move |_stream_monitor, item| {
                // we could check the kill state of the context and terminate the stream here;
                // if our transport needs a heartbeat, trigger a heartbeat here via the monitor
                if let Some(headers) = &item.headers {
                    match headers {
                        MockNetworkDataPlaneHeaders::HeartBeat => {
                            // todo(metrics): increment metric counter for heartbeats
                            // keep-alive only; the frame still flows through untouched below
                        }
                        MockNetworkDataPlaneHeaders::Sentinel => {
                            // todo(metrics): increment metric counter for sentinels
                            // the sender has signalled the end of the stream; returning None
                            // from `scan` terminates the decoded stream
                            return futures::future::ready(None);
                        }
                        _ => {}
                    }
                }
                futures::future::ready(Some(item))
            })
            // decode the response
            .map(move |item| {
                serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response")
            });
        // cancellation can be tricky and is transport / protocol specific.
        // in this case, our channel is both ordered and 1:1, thus we can first send the
        // request, then forward any cancellation requests; this ensures the downstream node
        // registers the context/request id before any cancellation requests are sent.
        // create the cancellation monitor object
        let cancellation_monitor = CancellationMonitor {
            ctx: context.clone(),
            ctrl_tx: self.ctrl_tx.clone(),
            finish_tx: finished_tx,
        };
        // launch the cancellation monitor task
        tokio::spawn(cancellation_monitor.execute());
        Ok(ResponseStream::new(Box::pin(decoded), context))
    }
}
/// For our MockNetworkTransport, the Ingress will be the one that will be receiving the
/// requests and pushes back the responses.
///
/// As such, the Ingress will be the one that will be responsible for receiving control plane messages.
#[allow(dead_code)]
pub struct MockNetworkIngress<Req: PipelineIO, Resp: PipelineIO> {
    // latency options; not read in this file — TODO confirm latency injection is still planned
    options: MockNetworkOptions,
    // receiver side of the shared control-plane channel
    ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>,
    // the downstream segment requests are dispatched to; set once via [`Self::segment`]
    segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
}
impl<Req: PipelineIO, Resp: PipelineIO> MockNetworkIngress<Req, Resp> {
    /// Creates an ingress that consumes events from the given control-plane receiver.
    /// The segment must be attached via [`Self::segment`] before calling `execute`.
    fn new(options: MockNetworkOptions, ctrl_rx: mpsc::Receiver<MockNetworkControlEvents>) -> Self {
        Self {
            options,
            ctrl_rx,
            segment: OnceLock::new(),
        }
    }

    /// Attaches the downstream segment that will serve incoming requests.
    ///
    /// # Errors
    /// Returns [`PipelineError::EdgeAlreadySet`] if a segment was already attached.
    pub fn segment(&self, segment: Arc<SegmentSource<Req, Resp>>) -> Result<(), PipelineError> {
        self.segment
            .set(segment)
            .map_err(|_| PipelineError::EdgeAlreadySet)
    }
}
impl
<
T
:
Data
,
U
:
Data
>
MockNetworkIngress
<
SingleIn
<
T
>
,
ManyOut
<
U
>>
where
T
:
Data
+
for
<
'de
>
Deserialize
<
'de
>
,
U
:
Data
+
Serialize
,
{
pub
async
fn
execute
(
self
)
->
Result
<
(),
PipelineError
>
{
let
mut
state
=
HashMap
::
<
String
,
Arc
<
dyn
AsyncEngineContext
>>
::
new
();
let
worker_id
=
uuid
::
Uuid
::
new_v4
()
.to_string
();
let
mut
ctrl_rx
=
self
.ctrl_rx
;
let
segment
=
self
.segment
.get
()
.expect
(
"segment not set"
)
.clone
();
while
let
Some
(
event
)
=
ctrl_rx
.recv
()
.await
{
match
event
{
MockNetworkControlEvents
::
ControlPlaneRequest
(
req
)
=>
{
// todo(metrics): increment metric counter for bytes received
// todo(metrics): increment metric counter for requests received
let
id
=
req
.id
.clone
();
tracing
::
debug!
(
"[ingress] received request [id: {}]"
,
id
);
// deserialize the request
let
request
=
serde_json
::
from_slice
::
<
T
>
(
&
req
.request
)
.expect
(
"failed to deserialize request"
);
// extend request with context
let
request
=
Context
::
<
T
>
::
with_id
(
request
,
req
.id
.clone
());
// create the response stream
let
response
=
segment
.generate
(
request
)
.await
;
let
handshake
=
match
&
response
{
Ok
(
_
)
=>
Handshake
{
request_id
:
req
.id
,
worker_id
:
Some
(
worker_id
.clone
()),
status
:
Status
::
Ok
,
},
Err
(
e
)
=>
Handshake
{
request_id
:
req
.id
,
worker_id
:
Some
(
worker_id
.clone
()),
status
:
Status
::
Error
(
e
.to_string
()),
},
};
tracing
::
debug!
(
"[ingress] sending handshake [id: {}]: {:?}"
,
id
,
handshake
);
// serialize the handshake
let
handshake
=
DataPlaneMessage
{
headers
:
Some
(
MockNetworkDataPlaneHeaders
::
Handshake
(
handshake
)),
body
:
vec!
[],
};
// send the handshake
req
.resp_tx
.send
(
handshake
)
.await
.expect
(
"failed to send handshake"
);
tracing
::
trace!
(
"[ingress] handshake sent [id: {}]"
,
id
);
if
let
Ok
(
response
)
=
response
{
// spawn a task to process the response stream:
// - serialize each response
// - forward the bytes to the data plane
tracing
::
debug!
(
"[ingress] processing response stream [id: {}]"
,
id
);
tokio
::
spawn
(
async
move
{
let
mut
response
=
response
;
while
let
Some
(
resp
)
=
response
.next
()
.await
{
tracing
::
trace!
(
"[ingress] received response [id: {}]"
,
id
);
let
resp_bytes
=
serde_json
::
to_vec
(
&
resp
)
.expect
(
"failed to serialize response"
);
let
msg
=
DataPlaneMessage
{
headers
:
None
,
body
:
resp_bytes
,
};
// send the response
req
.resp_tx
.send
(
msg
)
.await
.expect
(
"failed to send response"
);
tracing
::
trace!
(
"[ingress] sent response [id: {}]"
,
id
);
}
tracing
::
debug!
(
"response stream completed [id: {}]"
,
id
);
});
}
}
MockNetworkControlEvents
::
Cancel
(
id
)
=>
{
// todo(metrics): increment metric counter for cancelled requests
// todo(metrics): increment metric counter for bytes received
// todo(metrics): increment metric counter for requests received
// cancel the request
if
let
Some
(
tx
)
=
state
.remove
(
&
id
)
{
tx
.stop_generating
();
}
}
}
}
Ok
(())
}
}
// fn create_error_message(id: &str, e: &str) -> Hand {
// format!("Failed to deserialize request [id: {}]: {}", id, e)
// }
/// Object transferred to the Cancellation Monitor Task
///
/// The cancellation monitor task will be responsible for taking action on a
/// cancellation request.
///
/// This object holds a oneshot channel that will be used to signal the natural
/// termination of the stream.
///
/// Our cancellation monitor task select on those two signals and complete when
/// either of them is completed.
struct CancellationMonitor {
    // context of the in-flight stream; observed for a stop/kill signal
    ctx: Arc<StreamContext>,
    // control plane sender
    ctrl_tx: tokio::sync::mpsc::Sender<MockNetworkControlEvents>,
    // sender half of the completion oneshot; its `closed()` future resolves when the
    // receiving half (held by the ResponseMonitor inside the stream) is dropped,
    // which marks the stream as completed
    finish_tx: tokio::sync::oneshot::Sender<()>,
}
impl CancellationMonitor {
    /// Waits for either a cancellation signal on the stream context or the natural end
    /// of the stream, forwarding a `Cancel` event to the control plane in the former case.
    async fn execute(self) {
        // select on the finish channel and the kill signal
        let ctx = self.ctx;
        let ctrl_tx = self.ctrl_tx;
        let mut finish_tx = self.finish_tx;
        tokio::select! {
            _ = ctx.stopped() => {
                // todo(metrics): increment metric counter for cancelled requests
                // send a cancellation request to the control plane; errors are ignored
                // because a closed control plane means there is nothing left to cancel
                let _ = ctrl_tx
                    .send(MockNetworkControlEvents::Cancel(ctx.id().to_string()))
                    .await;
            }
            _ = finish_tx.closed() => {
                // the stream has completed naturally
            }
        }
    }
}
// held by the scan combinator
/// Keeps the stream context and the completion oneshot receiver alive for the lifetime
/// of the response stream; when the stream (and thus this monitor) is dropped, the
/// receiver closes and the paired [`CancellationMonitor`] observes natural completion.
#[allow(dead_code)]
struct ResponseMonitor {
    ctx: Arc<StreamContext>,
    finish_rx: tokio::sync::oneshot::Receiver<()>,
}
runtime/rust/tests/common/mod.rs
0 → 100644
View file @
5ed8c1c0
pub
mod
engines
;
pub
mod
mock
;
runtime/rust/tests/lifecycle.rs
0 → 100644
View file @
5ed8c1c0
use
triton_distributed
::{
worker
::
Worker
,
Result
,
Runtime
};
/// Minimal worker entry point used by the lifecycle test; does nothing and succeeds.
async fn hello_world(_runtime: Runtime) -> Result<()> {
    Ok(())
}
/// Smoke test: a worker built from settings can run a trivial entry point to completion.
#[test]
fn test_lifecycle() {
    Worker::from_settings()
        .unwrap()
        .execute(hello_world)
        .unwrap();
}
// async fn discoverable(runtime: Runtime) -> Result<()> {
// let config = DiscoveryConfig {
// etcd_url: vec!["http://localhost:2379".to_string()],
// etcd_connect_options: None,
// };
// let client = DiscoveryClient::new(config, runtime.clone()).await?;
// println!("Primary lease id: {:x}", client.lease_id());
// let lease = client.create_lease(60).await?;
// // Keys and values
// let lock_key = "lock_key"; // Key for the lock
// let object_key = "object_key"; // Key for the object
// let object_value = "This is the object value"; // Value for the object
// let lock_value = "locked"; // Value indicating a lock
// let put_options = Some(PutOptions::new().with_lease(lease.id()));
// // Build the transaction
// let txn = Txn::new()
// .when(vec![Compare::version(lock_key, CompareOp::Equal, 0)]) // Ensure the lock does not exist
// .and_then(vec![
// TxnOp::put(object_key, object_value, put_options.clone()), // Create the object
// TxnOp::put(lock_key, lock_value, put_options), // Set the lock
// ]);
// // Execute the transaction
// let txn_response = client.etc_client().kv_client().txn(txn).await?;
// tokio::spawn(async move {
// println!("custom lease id: {:x}", lease.id());
// lease.cancel_token().cancelled().await;
// println!("custom lease revoked");
// });
// runtime.child_token().cancelled().await;
// Ok(())
// }
// #[test]
// fn test_discovery_client() {
// let runtime = Runtime::new(RuntimeConfig::default()).unwrap();
// runtime.execute(discoverable).unwrap();
// }
runtime/rust/tests/pipeline.rs
0 → 100644
View file @
5ed8c1c0
use
futures
::{
stream
,
StreamExt
};
use
serde
::{
Deserialize
,
Serialize
};
use
std
::{
sync
::
Arc
,
time
::
Duration
};
use
triton_distributed
::
engine
::
ResponseStream
;
use
triton_distributed
::{
pipeline
::{
async_trait
,
AsyncEngine
,
Data
,
Event
,
ManyOut
,
Operator
,
ServiceBackend
,
ServiceEngine
,
ServiceFrontend
,
SingleIn
,
*
,
},
Error
,
};
mod
common
;
use
common
::
engines
::{
AsyncGenerator
,
LlmdbaEngine
as
LambdaEngine
};
use
common
::
mock
;
/// The [`super::engine::ResponseStream`] is annotated with the following types.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Annotated<T: Data> {
    /// The primary data which expected to be returned.
    Data(T),
    /// An actionable [`Event`] that can be handled.
    Event(Event),
    /// Additional information or metadata produced by the pipeline.
    Comment(String),
    /// An error produced by the pipeline. Multiple errors can be produced.
    Error(String),
    /// A sentinel value to indicate the end of the stream. This should not be emitted publicly.
    /// The implementation should be able to do the equivalent of a `.take_while` and trigger a
    /// stop if detected.
    End,
}
/// An [`Operator`] is used when you want to transform both the input and output of a pipeline.
/// In this case, our operator will perform the preprocessing step, but also add an annotation
/// to the output stream.
// NOTE(review): name is missing an "s" ("Preproces" -> "Preprocess"); renaming would
// touch all call sites, so it is kept as-is here.
struct PreprocesOperator {}
#[async_trait]
impl
Operator
<
SingleIn
<
String
>
,
ManyOut
<
Annotated
<
String
>>
,
SingleIn
<
String
>
,
ManyOut
<
Annotated
<
String
>>
,
>
for
PreprocesOperator
{
async
fn
generate
(
&
self
,
req
:
SingleIn
<
String
>
,
next
:
Arc
<
dyn
AsyncEngine
<
SingleIn
<
String
>
,
ManyOut
<
Annotated
<
String
>>
,
Error
>>
,
)
->
Result
<
ManyOut
<
Annotated
<
String
>>
,
Error
>
{
// capture some details about the request
let
prepend
=
vec!
[
Annotated
::
<
String
>
::
Comment
(
format!
(
"PreprocessOperator: {:?}"
,
req
))];
// we will append the result of this to the response stream via a chain
let
prepend_stream
=
stream
::
iter
(
prepend
);
// modify the request
let
req
=
req
.map
(|
x
|
format!
(
"{} from operator"
,
x
));
// issue the preprocessed request to the next engine
let
stream
=
next
.generate
(
req
)
.await
?
;
// capture the context of the response stream
let
ctx
=
stream
.context
();
// chain the prepend stream to the response stream
Ok
(
ResponseStream
::
new
(
Box
::
pin
(
prepend_stream
.chain
(
stream
)),
ctx
,
))
}
}
/// Builds the mock backend: emits one [`Annotated::Data`] per character of the request,
/// sleeping 10 ms between emissions, and returns early if the consumer goes away.
fn make_backend_engine() -> ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>> {
    LambdaEngine::from_generator(AsyncGenerator::<String, Annotated<String>>::new(
        |(req, stream)| async move {
            let chars = req.chars().collect::<Vec<char>>();
            for c in chars {
                match stream.emit(Annotated::Data(c.to_string())).await {
                    Ok(_) => {}
                    // emit failed: the receiving side was dropped, so stop producing
                    Err(_) => return,
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        },
    ))
}
#[tokio::test]
async
fn
test_service_source_sink
()
{
let
source
=
ServiceFrontend
::
<
SingleIn
<
String
>
,
ManyOut
<
Annotated
<
String
>>>
::
new
();
let
sink
=
ServiceBackend
::
from_engine
(
make_backend_engine
());
let
service
=
source
.link
(
sink
)
.unwrap
()
.link
(
source
)
.unwrap
();
let
mut
stream
=
service
.generate
(
"test"
.to_string
()
.into
())
.await
.unwrap
();
let
mut
counter
=
0
;
while
let
Some
(
_
output
)
=
stream
.next
()
.await
{
counter
+=
1
;
}
assert_eq!
(
counter
,
4
);
}
/// Builds a node that mimics prompt preprocessing by appending " world" to the request.
fn make_preprocessor() -> Arc<PipelineNode<SingleIn<String>, SingleIn<String>>> {
    PipelineNode::<SingleIn<String>, SingleIn<String>>::new(Box::new(|req| {
        Ok(req.map(|x| format!("{} world", x)))
    }))
}
/// Builds a node that doubles every item in the response stream, demonstrating that any
/// stream combinator (here `flat_map`) can transform the output side of a pipeline.
#[allow(clippy::type_complexity)]
fn make_postprocessor(
) -> Arc<PipelineNode<ManyOut<Annotated<String>>, ManyOut<Annotated<String>>>> {
    PipelineNode::<ManyOut<Annotated<String>>, ManyOut<Annotated<String>>>::new(Box::new(
        |req| {
            // Preserve the stream context so the doubled stream can be re-wrapped.
            let ctx = req.context();
            // Replace each item with two copies of itself.
            let doubled = req.flat_map(|item| stream::iter(vec![item.clone(), item]));
            Ok(ResponseStream::new(Box::pin(doubled), ctx))
        },
    ))
}
// Node 0:
// [frontend] -------[pre processor]-----> [backend]
// [frontend] <----- [post processor] ---- [backend]
/// Assembles a single-node service: frontend -> preprocessor -> backend -> postprocessor
/// -> frontend.
fn make_service(
) -> Result<ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>>, PipelineError> {
    // Frontend - Callable interface
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    // Mimics processing the prompt and tokenization
    let preprocess = make_preprocessor();
    // Mimics decoding; shows we can use any type of stream operation,
    // e.g. map, flat_map, fold, scan, etc. to transform the response stream
    let postprocess = make_postprocessor();
    // Mimics backend streaming by emitting each character of the input string
    let backend = ServiceBackend::from_engine(make_backend_engine());
    // LLM Pipelines are built by linking the frontend to the backend for input handling
    // then linking from the backend to the frontend for the output handling
    let service = frontend
        .link(preprocess)?
        .link(backend)?
        .link(postprocess)?
        .link(frontend)?;
    Ok(service)
}
#[tokio::test]
async
fn
test_service_source_node_sink
()
{
let
service
=
make_service
()
.unwrap
();
let
mut
stream
=
service
.generate
(
"test"
.to_string
()
.into
())
.await
.unwrap
();
let
mut
counter
=
0
;
while
let
Some
(
_
output
)
=
stream
.next
()
.await
{
counter
+=
1
;
}
assert_eq!
(
counter
,
20
);
}
// Put the post process on node 0, but the preprocessor and the compute on node1
// Node 0:
// [frontend] ---------------------------> [segment_sink]
// [frontend] <----- [post processor] ---- [segment_sink]
//
// Node 1:
// [segment_source] ---- [preprocessor] ---> [backend]
// [segment_source] <----------------------- [backend]
/// Splits the pipeline across two "nodes" connected by the mock transport and verifies
/// the same 20-item result as the single-node test.
#[tokio::test]
async fn test_disaggregated_service() {
    println!("Running test_disaggregated_service");
    // Node 0: frontend + postprocessor, ending in a segment sink (the egress side)
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let postprocessor = make_postprocessor();
    let end_node_0 = SegmentSink::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let node0_service = frontend
        .link(end_node_0.clone())
        .unwrap()
        .link(postprocessor)
        .unwrap()
        .link(frontend)
        .unwrap();
    // Node 1: segment source (the ingress side) + preprocessor + backend
    let start_node1 = SegmentSource::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    let preprocessor = make_preprocessor();
    let backend = ServiceBackend::from_engine(make_backend_engine());
    let node1_service = start_node1
        .link(preprocessor)
        .unwrap()
        .link(backend)
        .unwrap()
        .link(start_node1.clone())
        .unwrap();
    // Wire the two nodes together over the mock network transport
    let opts = mock::MockNetworkOptions::default();
    let (egress, ingress) = mock::MockNetworkTransport::<
        SingleIn<String>,
        ManyOut<Annotated<String>>,
    >::new_egress_ingress(opts);
    end_node_0.attach(egress).unwrap();
    ingress.segment(node1_service).unwrap();
    // Run the ingress event loop in the background
    tokio::spawn(ingress.execute());
    // "test world" = 10 characters, doubled by the postprocessor -> 20 items
    let mut stream = node0_service
        .generate("test".to_string().into())
        .await
        .unwrap();
    let mut counter = 0;
    while let Some(_output) = stream.next().await {
        counter += 1;
    }
    assert_eq!(counter, 20);
}
// Node 0:
// [frontend] --> [pre processor] --> [operator] ----------------------> [backend]
// [frontend] <---------------------- [operator] <--[post processor] <-- [backend]
/// Like [`make_service`], but routes both directions through a [`PipelineOperator`]
/// wrapping [`PreprocesOperator`], which rewrites requests and annotates responses.
fn make_service_with_operator(
) -> Result<ServiceEngine<SingleIn<String>, ManyOut<Annotated<String>>>, PipelineError> {
    // Frontend - Callable interface
    let frontend = ServiceFrontend::<SingleIn<String>, ManyOut<Annotated<String>>>::new();
    // Mimics processing the prompt and tokenization
    let preprocess = make_preprocessor();
    // Mimics decoding; shows we can use any type of stream operation,
    // e.g. map, flat_map, fold, scan, etc. to transform the response stream
    let postprocess = make_postprocessor();
    // Mimics backend streaming by emitting each character of the input string
    let backend = ServiceBackend::from_engine(make_backend_engine());
    let operator = PipelineOperator::new(Arc::new(PreprocesOperator {}));
    // LLM Pipelines are built by linking the frontend to the backend for input handling
    // then linking from the backend to the frontend for the output handling
    let service = frontend
        .link(preprocess)?
        .link(operator.forward_edge())?
        .link(backend)?
        .link(postprocess)?
        .link(operator.backward_edge())?
        .link(frontend)?;
    Ok(service)
}
/// With the operator in the loop: one Comment annotation is prepended, and the request
/// becomes "test world from operator" (24 chars), doubled by the postprocessor -> 48
/// Data items.
#[tokio::test]
async fn test_service_source_node_sink_with_operator() {
    let service = make_service_with_operator().unwrap();
    let stream = service.generate("test".to_string().into()).await.unwrap();
    // Tally Data and Comment items in a single fold over the stream.
    let (data_count, comment_count) = stream
        .fold((0, 0), |(data, comments), output| async move {
            match output {
                Annotated::Data(_) => (data + 1, comments),
                Annotated::Comment(_) => (data, comments + 1),
                _ => (data, comments),
            }
        })
        .await;
    assert_eq!(comment_count, 1);
    assert_eq!(data_count, 48);
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment