Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
c4213899
Unverified
Commit
c4213899
authored
May 06, 2025
by
jthomson04
Committed by
GitHub
May 06, 2025
Browse files
feat: Migrate NATS Queue to Rust (#669) (#961)
parent
2d4f8b50
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
259 additions
and
78 deletions
+259
-78
examples/llm/utils/nats_queue.py
examples/llm/utils/nats_queue.py
+8
-78
lib/bindings/python/rust/lib.rs
lib/bindings/python/rust/lib.rs
+1
-0
lib/bindings/python/rust/llm.rs
lib/bindings/python/rust/llm.rs
+1
-0
lib/bindings/python/rust/llm/nats.rs
lib/bindings/python/rust/llm/nats.rs
+103
-0
lib/runtime/src/transports/nats.rs
lib/runtime/src/transports/nats.rs
+146
-0
No files found.
examples/llm/utils/nats_queue.py
View file @
c4213899
...
...
@@ -18,10 +18,7 @@ import asyncio
from
contextlib
import
asynccontextmanager
from
typing
import
ClassVar
,
Optional
from
nats.aio.client
import
Client
as
NATS
from
nats.errors
import
Error
as
NatsError
from
nats.js.client
import
JetStreamContext
from
nats.js.errors
import
NotFoundError
from
dynamo._core
import
NatsQueue
class
NATSQueue
:
...
...
@@ -34,15 +31,7 @@ class NATSQueue:
nats_server
:
str
=
"nats://localhost:4222"
,
dequeue_timeout
:
float
=
1
,
):
self
.
nats_url
=
nats_server
self
.
_nc
:
Optional
[
NATS
]
=
None
self
.
_js
:
Optional
[
JetStreamContext
]
=
None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self
.
_stream_name
=
stream_name
.
replace
(
"/"
,
"_"
).
replace
(
"
\\
"
,
"_"
)
self
.
_subject
=
f
"
{
self
.
_stream_name
}
.*"
self
.
dequeue_timeout
=
dequeue_timeout
self
.
_subscriber
:
Optional
[
JetStreamContext
.
PullSubscription
]
=
None
self
.
nats_q
=
NatsQueue
(
stream_name
,
nats_server
,
dequeue_timeout
)
@
classmethod
@
asynccontextmanager
...
...
@@ -81,79 +70,20 @@ class NATSQueue:
cls
.
_instance
=
None
async
def
connect
(
self
):
"""Establish connection and create stream if needed"""
try
:
if
self
.
_nc
is
None
:
self
.
_nc
=
NATS
()
await
self
.
_nc
.
connect
(
self
.
nats_url
)
self
.
_js
=
self
.
_nc
.
jetstream
()
# Check if stream exists, if not create it
try
:
await
self
.
_js
.
stream_info
(
self
.
_stream_name
)
except
NotFoundError
:
await
self
.
_js
.
add_stream
(
name
=
self
.
_stream_name
,
subjects
=
[
self
.
_subject
],
# TODO: make these configurable and add guide to set these values
max_bytes
=
1073741824
,
# 1GB total storage limit
max_msgs
=
1000000
,
# 1 million messages limit
)
# Create persistent subscriber
self
.
_subscriber
=
await
self
.
_js
.
pull_subscribe
(
f
"
{
self
.
_stream_name
}
.queue"
,
durable
=
"worker-group"
)
except
NatsError
as
e
:
await
self
.
close
()
raise
ConnectionError
(
f
"Failed to connect to NATS:
{
e
}
"
)
await
self
.
nats_q
.
connect
()
async
def
ensure_connection
(
self
):
"""Ensure we have an active connection"""
if
self
.
_nc
is
None
or
self
.
_nc
.
is_closed
:
await
self
.
connect
()
await
self
.
nats_q
.
ensure_connection
()
async
def
close
(
self
):
"""Close the connection when done"""
if
self
.
_nc
:
await
self
.
_nc
.
close
()
self
.
_nc
=
None
self
.
_js
=
None
self
.
_subscriber
=
None
await
self
.
nats_q
.
close
()
# TODO: is enqueue/dequeue_object a better name for a general queue?
async
def
enqueue_task
(
self
,
task_data
:
bytes
)
->
None
:
"""
Enqueue a task using msgspec-encoded data
"""
await
self
.
ensure_connection
()
try
:
await
self
.
_js
.
publish
(
f
"
{
self
.
_stream_name
}
.queue"
,
task_data
)
# type: ignore
except
NatsError
as
e
:
raise
RuntimeError
(
f
"Failed to enqueue task:
{
e
}
"
)
await
self
.
nats_q
.
enqueue_task
(
task_data
)
async
def
dequeue_task
(
self
)
->
Optional
[
bytes
]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await
self
.
ensure_connection
()
try
:
msgs
=
await
self
.
_subscriber
.
fetch
(
1
,
timeout
=
self
.
dequeue_timeout
)
# type: ignore
if
msgs
:
msg
=
msgs
[
0
]
await
msg
.
ack
()
return
msg
.
data
return
None
except
asyncio
.
TimeoutError
:
return
None
except
NatsError
as
e
:
raise
RuntimeError
(
f
"Failed to dequeue task:
{
e
}
"
)
return
await
self
.
nats_q
.
dequeue_task
()
async
def
get_queue_size
(
self
)
->
int
:
"""Get the number of messages currently in the queue"""
await
self
.
ensure_connection
()
try
:
# Get consumer info to get pending messages count
consumer_info
=
await
self
.
_js
.
consumer_info
(
# type: ignore
self
.
_stream_name
,
"worker-group"
)
# Return number of pending messages (real-time queue size)
return
consumer_info
.
num_pending
except
NatsError
as
e
:
raise
RuntimeError
(
f
"Failed to get queue size:
{
e
}
"
)
return
await
self
.
nats_q
.
get_queue_size
()
lib/bindings/python/rust/lib.rs
View file @
c4213899
...
...
@@ -74,6 +74,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m
.add_class
::
<
llm
::
kv
::
KvMetricsAggregator
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvEventPublisher
>
()
?
;
m
.add_class
::
<
llm
::
kv
::
KvRecorder
>
()
?
;
m
.add_class
::
<
llm
::
nats
::
NatsQueue
>
()
?
;
m
.add_class
::
<
http
::
HttpService
>
()
?
;
m
.add_class
::
<
http
::
HttpError
>
()
?
;
m
.add_class
::
<
http
::
HttpAsyncEngine
>
()
?
;
...
...
lib/bindings/python/rust/llm.rs
View file @
c4213899
...
...
@@ -42,4 +42,5 @@ pub mod backend;
pub
mod
disagg_router
;
pub
mod
kv
;
pub
mod
model_card
;
pub
mod
nats
;
pub
mod
preprocessor
;
lib/bindings/python/rust/llm/nats.rs
0 → 100644
View file @
c4213899
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
#[pyclass(subclass)]
pub
(
crate
)
struct
NatsQueue
{
inner
:
Arc
<
Mutex
<
crate
::
rs
::
transports
::
nats
::
NatsQueue
>>
,
}
#[pymethods]
impl
NatsQueue
{
#[new]
#[pyo3(signature
=
(stream_name,
nats_server,
dequeue_timeout))]
fn
new
(
stream_name
:
String
,
nats_server
:
String
,
dequeue_timeout
:
f64
)
->
PyResult
<
Self
>
{
let
inner
=
Arc
::
new
(
Mutex
::
new
(
crate
::
rs
::
transports
::
nats
::
NatsQueue
::
new
(
stream_name
,
nats_server
,
std
::
time
::
Duration
::
from_secs
(
dequeue_timeout
as
u64
),
)));
Ok
(
Self
{
inner
})
}
fn
connect
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
queue
.lock
()
.await
.connect
()
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
fn
ensure_connection
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
queue
.lock
()
.await
.ensure_connection
()
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
fn
close
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
queue
.lock
()
.await
.close
()
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
fn
enqueue_task
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
,
task_data
:
Py
<
PyBytes
>
,
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
bytes
=
task_data
.as_bytes
(
py
)
.to_vec
();
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
queue
.lock
()
.await
.enqueue_task
(
bytes
.into
())
.await
.map_err
(
to_pyerr
)
?
;
Ok
(())
})
}
fn
dequeue_task
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
Ok
(
queue
.lock
()
.await
.dequeue_task
()
.await
.map_err
(
to_pyerr
)
?
.map
(|
bytes
|
bytes
.to_vec
()))
})
}
fn
get_queue_size
<
'p
>
(
&
mut
self
,
py
:
Python
<
'p
>
)
->
PyResult
<
Bound
<
'p
,
PyAny
>>
{
let
queue
=
self
.inner
.clone
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
queue
.lock
()
.await
.get_queue_size
()
.await
.map_err
(
to_pyerr
)
})
}
}
lib/runtime/src/transports/nats.rs
View file @
c4213899
...
...
@@ -341,6 +341,152 @@ pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
Ok
((
bucket
.to_string
(),
key
.to_string
()))
}
/// A queue implementation using NATS JetStream
pub
struct
NatsQueue
{
/// The name of the stream to use for the queue
stream_name
:
String
,
/// The NATS server URL
nats_server
:
String
,
/// Timeout for dequeue operations in seconds
dequeue_timeout
:
time
::
Duration
,
/// The NATS client
client
:
Option
<
Client
>
,
/// The subject pattern used for this queue
subject
:
String
,
/// The subscriber for pull-based consumption
subscriber
:
Option
<
jetstream
::
consumer
::
PullConsumer
>
,
}
impl
NatsQueue
{
/// Create a new NatsQueue with the given configuration
pub
fn
new
(
stream_name
:
String
,
nats_server
:
String
,
dequeue_timeout
:
time
::
Duration
)
->
Self
{
// Sanitize stream name to remove path separators (like in Python version)
let
sanitized_stream_name
=
stream_name
.replace
([
'/'
,
'\\'
],
"_"
);
let
subject
=
format!
(
"{}.*"
,
sanitized_stream_name
);
Self
{
stream_name
:
sanitized_stream_name
,
nats_server
,
dequeue_timeout
,
client
:
None
,
subject
,
subscriber
:
None
,
}
}
/// Connect to the NATS server and set up the stream and consumer
pub
async
fn
connect
(
&
mut
self
)
->
Result
<
()
>
{
if
self
.client
.is_none
()
{
// Create a new client
let
client_options
=
Client
::
builder
()
.server
(
self
.nats_server
.clone
())
.build
()
?
;
let
client
=
client_options
.connect
()
.await
?
;
// Check if stream exists, if not create it
let
streams
=
client
.list_streams
()
.await
?
;
if
!
streams
.contains
(
&
self
.stream_name
)
{
log
::
debug!
(
"Creating NATS stream {}"
,
self
.stream_name
);
let
stream_config
=
jetstream
::
stream
::
Config
{
name
:
self
.stream_name
.clone
(),
subjects
:
vec!
[
self
.subject
.clone
()],
..
Default
::
default
()
};
client
.jetstream
()
.create_stream
(
stream_config
)
.await
?
;
}
// Create persistent subscriber
let
consumer_config
=
jetstream
::
consumer
::
pull
::
Config
{
durable_name
:
Some
(
"worker-group"
.to_string
()),
..
Default
::
default
()
};
let
stream
=
client
.jetstream
()
.get_stream
(
&
self
.stream_name
)
.await
?
;
let
subscriber
=
stream
.create_consumer
(
consumer_config
)
.await
?
;
self
.subscriber
=
Some
(
subscriber
);
self
.client
=
Some
(
client
);
}
Ok
(())
}
/// Ensure we have an active connection
pub
async
fn
ensure_connection
(
&
mut
self
)
->
Result
<
()
>
{
if
self
.client
.is_none
()
{
self
.connect
()
.await
?
;
}
Ok
(())
}
/// Close the connection when done
pub
async
fn
close
(
&
mut
self
)
->
Result
<
()
>
{
self
.subscriber
=
None
;
self
.client
=
None
;
Ok
(())
}
/// Enqueue a task using the provided data
pub
async
fn
enqueue_task
(
&
mut
self
,
task_data
:
Bytes
)
->
Result
<
()
>
{
self
.ensure_connection
()
.await
?
;
if
let
Some
(
client
)
=
&
self
.client
{
let
subject
=
format!
(
"{}.queue"
,
self
.stream_name
);
client
.jetstream
()
.publish
(
subject
,
task_data
)
.await
?
;
Ok
(())
}
else
{
Err
(
anyhow
::
anyhow!
(
"Client not connected"
))
}
}
/// Dequeue and return a task as raw bytes
pub
async
fn
dequeue_task
(
&
mut
self
)
->
Result
<
Option
<
Bytes
>>
{
self
.ensure_connection
()
.await
?
;
if
let
Some
(
subscriber
)
=
&
self
.subscriber
{
let
mut
batch
=
subscriber
.fetch
()
.expires
(
self
.dequeue_timeout
)
.max_messages
(
1
)
.messages
()
.await
?
;
if
let
Some
(
message
)
=
batch
.next
()
.await
{
let
message
=
message
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to get message: {}"
,
e
))
?
;
message
.ack
()
.await
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to ack message: {}"
,
e
))
?
;
Ok
(
Some
(
message
.payload
.clone
()))
}
else
{
Ok
(
None
)
}
}
else
{
Err
(
anyhow
::
anyhow!
(
"Subscriber not initialized"
))
}
}
/// Get the number of messages currently in the queue
pub
async
fn
get_queue_size
(
&
mut
self
)
->
Result
<
u64
>
{
self
.ensure_connection
()
.await
?
;
if
let
Some
(
client
)
=
&
self
.client
{
// Get consumer info to get pending messages count
let
stream
=
client
.jetstream
()
.get_stream
(
&
self
.stream_name
)
.await
?
;
let
mut
consumer
:
jetstream
::
consumer
::
PullConsumer
=
stream
.get_consumer
(
"worker-group"
)
.await
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to get consumer: {}"
,
e
))
?
;
let
info
=
consumer
.info
()
.await
?
;
Ok
(
info
.num_pending
)
}
else
{
Err
(
anyhow
::
anyhow!
(
"Client not connected"
))
}
}
}
#[cfg(test)]
mod
tests
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment