OpenDAS / dynamo · Commit 1af7433b

refactor: rename triton_distributed to dynemo (#22)

Authored by Neelay Shah on Mar 05, 2025; committed by GitHub on Mar 05, 2025.
Co-authored-by: Graham King <grahamk@nvidia.com>
Parent: ee4ef06b
Changes: 165

Showing 20 changed files with 383 additions and 649 deletions (+383 −649)
.dockerignore                                                   +1   −1
.github/workflows/pre-merge.yml                                 +0   −57
ATTRIBUTIONS.md                                                 +1   −1
README.md                                                       +1   −1
applications/llm/count/Cargo.lock                               +95  −95
applications/llm/count/Cargo.toml                               +3   −3
applications/llm/count/README.md                                +6   −6
applications/llm/count/src/bin/mock_worker.rs                   +5   −5
applications/llm/count/src/lib.rs                               +4   −6
applications/llm/count/src/main.rs                              +3   −3
container/Dockerfile                                            +17  −17
container/Dockerfile.vllm                                       +18  −18
container/Dockerfile.vllm_nixl                                  +18  −18
container/deps/clone_tensorrtllm.sh                             +5   −5
container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch    +188 −396
deploy/compoundai/sdk/src/compoundai/cli/serve_nova.py          +6   −5
deploy/compoundai/sdk/src/compoundai/sdk/decorators.py          +1   −1
deploy/compoundai/sdk/src/compoundai/sdk/dependency.py          +2   −2
examples/python_rs/llm/tensorrt_llm/README.md                   +6   −6
examples/python_rs/llm/tensorrt_llm/common/client.py            +3   −3
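The commit is an almost entirely mechanical repo-wide rename of `triton_distributed`/`triton` identifiers to `dynemo`. As an illustration only (this is not the script used to produce the commit, and it rewrites file contents but not file names), such a rename can be sketched as a recursive search-and-replace, applying the longest pattern first so `triton_distributed` is not mangled by a bare `triton` pass:

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Recursively replace `from` with `to` in every UTF-8 text file under `dir`.
/// Files that are not valid UTF-8 (binaries) are skipped.
fn rename_in_tree(dir: &Path, from: &str, to: &str) -> io::Result<()> {
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            rename_in_tree(&path, from, to)?;
        } else if let Ok(text) = fs::read_to_string(&path) {
            if text.contains(from) {
                fs::write(&path, text.replace(from, to))?;
            }
        }
    }
    Ok(())
}

fn main() -> io::Result<()> {
    // Demonstrate on a throwaway directory rather than a real checkout.
    let dir = std::env::temp_dir().join("rename_demo");
    fs::create_dir_all(&dir)?;
    fs::write(dir.join("lib.rs"), "use triton_distributed_llm::kv_router;")?;
    rename_in_tree(&dir, "triton_distributed", "dynemo")?;
    println!("{}", fs::read_to_string(dir.join("lib.rs"))?);
    Ok(())
}
```

Note that several renames in this commit are not pure substitutions (e.g. `TRD_` env prefixes become `DYN_`, and the patch file itself is renamed), so the real change required per-pattern passes plus `git mv`.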
.dockerignore

@@ -19,7 +19,7 @@
 **/*.plan
 **/.cache/*
 **/*onnx*
-# Engine must be allowed because code contains triton_distributed_engine.py
+# Engine must be allowed because code contains dynemo_engine.py
 **/*tensorrtllm_engines*
 **/*tensorrtllm_models*
 **/*tensorrtllm_checkpoints*
.github/workflows/pre-merge.yml

@@ -22,25 +22,6 @@ on:
 jobs:
-  # icp_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: ./icp/protos/gen_python.sh
-  #     - run: pytest --verbose icp
-  #       timeout-minutes: 3
   pre-commit:
     runs-on: ubuntu-latest
     permissions:

@@ -52,41 +33,3 @@ jobs:
       timeout-minutes: 3
-  # providers_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #       PROTO_OUT: /python/icp/protos
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: pytest --verbose providers
-  # worker_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #       PROTO_OUT: /python/icp/protos
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: ./icp/protos/gen_python.sh
-  #     - run: pytest -p no:warnings --verbose worker/python/tests
-  #       timeout-minutes: 2
ATTRIBUTIONS.md

@@ -17,7 +17,7 @@ limitations under the License.
 # Open Source License Attribution
 
-Triton Distributed uses Open Source components. You can find the details of these open-source projects along with license information below.
+Dynemo uses Open Source components. You can find the details of these open-source projects along with license information below.
 We are grateful to the developers for their contributions to open source and acknowledge these below.
 
 ## nats-py - [Apache License 2.0](https://github.com/nats-io/nats.py/blob/main/LICENSE)
README.md

@@ -71,7 +71,7 @@ The run script offers a few common workflows:
 1. Running a command in a container and exiting.
    ```
-   ./container/run.sh -- python3 -c "import triton_distributed.runtime; help(triton_distributed.runtime)"
+   ./container/run.sh -- python3 -c "import dynemo.runtime; help(dynemo.runtime)"
    ```
 2. Starting an interactive shell.
applications/llm/count/Cargo.lock

@@ -737,6 +737,8 @@ version = "0.1.0"
 dependencies = [
  "axum 0.6.20",
  "clap",
+ "dynemo-llm",
+ "dynemo-runtime",
  "opentelemetry",
  "opentelemetry-prometheus",
  "prometheus",

@@ -747,8 +749,6 @@ dependencies = [
  "thiserror 1.0.69",
  "tokio",
  "tracing",
- "triton-distributed-llm",
- "triton-distributed-runtime",
 ]
 
 [[package]]

@@ -1024,6 +1024,99 @@ dependencies = [
  "syn 2.0.98",
 ]
 
+[[package]]
+name = "dynemo-llm"
+version = "0.2.1"
+dependencies = [
+ "anyhow",
+ "async-openai",
+ "async-stream",
+ "async-trait",
+ "axum 0.8.1",
+ "bindgen",
+ "blake3",
+ "bs62",
+ "bytes",
+ "chrono",
+ "cmake",
+ "derive_builder",
+ "dynemo-runtime",
+ "either",
+ "erased-serde",
+ "futures",
+ "galil-seiferas",
+ "indexmap 2.7.1",
+ "itertools 0.14.0",
+ "libc",
+ "minijinja",
+ "minijinja-contrib",
+ "prometheus",
+ "pyo3",
+ "regex",
+ "semver",
+ "serde",
+ "serde-pickle",
+ "serde_json",
+ "serde_repr",
+ "strum",
+ "thiserror 2.0.11",
+ "tokenizers",
+ "tokio",
+ "tokio-stream",
+ "tokio-util",
+ "toktrie",
+ "toktrie_hf_tokenizers",
+ "tracing",
+ "unicode-segmentation",
+ "uuid",
+ "validator",
+ "xxhash-rust",
+]
+
+[[package]]
+name = "dynemo-runtime"
+version = "0.2.1"
+dependencies = [
+ "anyhow",
+ "async-nats",
+ "async-once-cell",
+ "async-stream",
+ "async-trait",
+ "async_zmq",
+ "blake3",
+ "bytes",
+ "chrono",
+ "derive-getters",
+ "derive_builder",
+ "educe",
+ "either",
+ "etcd-client",
+ "figment",
+ "futures",
+ "humantime",
+ "local-ip-address",
+ "log",
+ "nid",
+ "nix",
+ "nuid",
+ "once_cell",
+ "prometheus",
+ "rand",
+ "regex",
+ "serde",
+ "serde_json",
+ "socket2",
+ "thiserror 1.0.69",
+ "tokio",
+ "tokio-stream",
+ "tokio-util",
+ "tracing",
+ "tracing-subscriber",
+ "uuid",
+ "validator",
+ "xxhash-rust",
+]
+
 [[package]]
 name = "ed25519"
 version = "2.2.3"

@@ -4232,99 +4325,6 @@ dependencies = [
  "tracing-serde",
 ]
 
-[[package]]
-name = "triton-distributed-llm"
-version = "0.2.1"
-dependencies = [
- "anyhow",
- "async-openai",
- "async-stream",
- "async-trait",
- "axum 0.8.1",
- "bindgen",
- "blake3",
- "bs62",
- "bytes",
- "chrono",
- "cmake",
- "derive_builder",
- "either",
- "erased-serde",
- "futures",
- "galil-seiferas",
- "indexmap 2.7.1",
- "itertools 0.14.0",
- "libc",
- "minijinja",
- "minijinja-contrib",
- "prometheus",
- "pyo3",
- "regex",
- "semver",
- "serde",
- "serde-pickle",
- "serde_json",
- "serde_repr",
- "strum",
- "thiserror 2.0.11",
- "tokenizers",
- "tokio",
- "tokio-stream",
- "tokio-util",
- "toktrie",
- "toktrie_hf_tokenizers",
- "tracing",
- "triton-distributed-runtime",
- "unicode-segmentation",
- "uuid",
- "validator",
- "xxhash-rust",
-]
-
-[[package]]
-name = "triton-distributed-runtime"
-version = "0.2.1"
-dependencies = [
- "anyhow",
- "async-nats",
- "async-once-cell",
- "async-stream",
- "async-trait",
- "async_zmq",
- "blake3",
- "bytes",
- "chrono",
- "derive-getters",
- "derive_builder",
- "educe",
- "either",
- "etcd-client",
- "figment",
- "futures",
- "humantime",
- "local-ip-address",
- "log",
- "nid",
- "nix",
- "nuid",
- "once_cell",
- "prometheus",
- "rand",
- "regex",
- "serde",
- "serde_json",
- "socket2",
- "thiserror 1.0.69",
- "tokio",
- "tokio-stream",
- "tokio-util",
- "tracing",
- "tracing-subscriber",
- "uuid",
- "validator",
- "xxhash-rust",
-]
-
 [[package]]
 name = "try-lock"
 version = "0.2.5"
applications/llm/count/Cargo.toml

@@ -21,8 +21,8 @@ license = "Apache-2.0"
 [dependencies]
 
 # local
-triton-distributed-runtime = { path = "../../../lib/runtime" }
-triton-distributed-llm = { path = "../../../lib/llm" }
+dynemo-runtime = { path = "../../../lib/runtime" }
+dynemo-llm = { path = "../../../lib/llm" }
 
 # workspace - todo
applications/llm/count/README.md

@@ -8,17 +8,17 @@ the services associated with that endpoint, do some postprocessing on them,
 and then publish an event with the postprocessed data.
 
 ```bash
-# For more details, try TRD_LOG=debug
-TRD_LOG=info cargo run --bin count -- --namespace triton-init --component backend --endpoint generate
+# For more details, try DYN_LOG=debug
+DYN_LOG=info cargo run --bin count -- --namespace dynemo --component backend --endpoint generate
 
-# 2025-02-26T18:45:05.467026Z INFO count: Creating unique instance of Count at triton-init/components/count/instance
-# 2025-02-26T18:45:05.472146Z INFO count: Scraping service triton_init_backend_720278f8 and filtering on subject triton_init_backend_720278f8.generate
+# 2025-02-26T18:45:05.467026Z INFO count: Creating unique instance of Count at dynemo/components/count/instance
+# 2025-02-26T18:45:05.472146Z INFO count: Scraping service dynemo_init_backend_720278f8 and filtering on subject dynemo_init_backend_720278f8.generate
 # ...
 ```
 
 With no matching endpoints running, you should see warnings in the logs:
 
 ```bash
-2025-02-26T18:45:06.474161Z WARN count: No endpoints found matching subject triton_init_backend_720278f8.generate
+2025-02-26T18:45:06.474161Z WARN count: No endpoints found matching subject dynemo_init_backend_720278f8.generate
 ```
 
 To see metrics published to a matching endpoint, you can use the

@@ -35,7 +35,7 @@ since the endpoint will automatically get discovered.
 When stats are found from the target endpoints being listened on, count will
 aggregate and publish some metrics as both an event and to a prometheus web server:
 
 ```
-2025-02-28T04:05:58.077901Z INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
+2025-02-28T04:05:58.077901Z INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
 ```
 
 To see the metrics being published in prometheus format, you can run:
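The `load_avg: 53.0, load_std: 24.0` figures in the log line above are consistent with per-worker KV-block utilization: the two workers report 77 and 29 active blocks out of 100, which average to 53.0 with a population standard deviation of 24.0. A minimal sketch of that aggregation (an illustration derived from the logged numbers; the actual `count` implementation may compute load differently):

```rust
/// Population mean and standard deviation of per-worker KV-block
/// utilization, expressed as percentages of total blocks.
fn load_stats(kv_active: &[u64], kv_total: &[u64]) -> (f64, f64) {
    let pct: Vec<f64> = kv_active
        .iter()
        .zip(kv_total)
        .map(|(a, t)| 100.0 * *a as f64 / *t as f64)
        .collect();
    let n = pct.len() as f64;
    let avg = pct.iter().sum::<f64>() / n;
    let var = pct.iter().map(|x| (x - avg).powi(2)).sum::<f64>() / n;
    (avg, var.sqrt())
}

fn main() {
    // The two workers from the log: 77/100 and 29/100 KV blocks active.
    let (avg, std) = load_stats(&[77, 29], &[100, 100]);
    println!("load_avg: {avg}, load_std: {std}");
}
```

Running this on the logged ForwardPassMetrics reproduces the published `load_avg`/`load_std` values exactly.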
applications/llm/count/src/bin/mock_worker.rs

@@ -13,10 +13,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use rand::Rng;
-use std::sync::Arc;
-use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
-use triton_distributed_runtime::{
+use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
+use dynemo_runtime::{
     logging,
     pipeline::{
         async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,

@@ -25,6 +23,8 @@ use triton_distributed_runtime::{
     protocols::annotated::Annotated,
     stream, DistributedRuntime, Result, Runtime, Worker,
 };
+use rand::Rng;
+use std::sync::Arc;
 
 fn main() -> Result<()> {
     logging::init();

@@ -69,7 +69,7 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
     // we must first create a service, then we can attach one more more endpoints
     runtime
-        .namespace("triton-init")?
+        .namespace("dynemo")?
         .component("backend")?
         .service_builder()
         .create()
applications/llm/count/src/lib.rs

@@ -20,13 +20,11 @@ use prometheus::register_gauge_vec;
 use serde::{Deserialize, Serialize};
 use std::net::SocketAddr;
-use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
-use triton_distributed_llm::kv_router::scheduler::Endpoint;
-use triton_distributed_llm::kv_router::scoring::ProcessedEndpoints;
-use triton_distributed_runtime::{
-    distributed::Component, service::EndpointInfo, utils::Duration, Result,
-};
+use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
+use dynemo_llm::kv_router::scheduler::Endpoint;
+use dynemo_llm::kv_router::scoring::ProcessedEndpoints;
+use dynemo_runtime::{distributed::Component, service::EndpointInfo, utils::Duration, Result};
 
 /// Configuration for LLM worker load capacity metrics
 #[derive(Debug, Clone, Serialize, Deserialize)]
applications/llm/count/src/main.rs

@@ -24,7 +24,7 @@
 //! - KV Cache Blocks: [Active, Total]
 
 use clap::Parser;
-use triton_distributed_runtime::{
+use dynemo_runtime::{
     error, logging,
     traits::events::EventPublisher,
     utils::{Duration, Instant},

@@ -50,7 +50,7 @@ struct Args {
     endpoint: String,
 
     /// Namespace to operate in
-    #[arg(long, env = "TRD_NAMESPACE", default_value = "triton-init")]
+    #[arg(long, env = "DYN_NAMESPACE", default_value = "dynemo")]
     namespace: String,
 
     /// Polling interval in seconds (minimum 1 second)

@@ -155,7 +155,7 @@ mod tests {
     #[test]
     fn test_namespace_from_env() {
-        env::set_var("TRD_NAMESPACE", "test-namespace");
+        env::set_var("DYN_NAMESPACE", "test-namespace");
         let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);
         assert_eq!(args.namespace, "test-namespace");
     }
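The renamed flag keeps the same shape: an environment-backed value with a default, now `DYN_NAMESPACE`/`"dynemo"` instead of `TRD_NAMESPACE`/`"triton-init"`. Stripped of clap, the lookup reduces to the following sketch (simplified for illustration; the real code routes this through `#[arg(long, env = "DYN_NAMESPACE", default_value = "dynemo")]`):

```rust
use std::env;

/// Resolve the namespace the way the renamed flag does: prefer the
/// DYN_NAMESPACE environment variable, fall back to "dynemo".
fn resolve_namespace() -> String {
    env::var("DYN_NAMESPACE").unwrap_or_else(|_| "dynemo".to_string())
}

fn main() {
    env::remove_var("DYN_NAMESPACE");
    assert_eq!(resolve_namespace(), "dynemo");
    env::set_var("DYN_NAMESPACE", "test-namespace");
    assert_eq!(resolve_namespace(), "test-namespace");
}
```

With clap, a command-line `--namespace` would take precedence over the environment variable, which this sketch omits.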
container/Dockerfile

@@ -16,7 +16,7 @@
 ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
 ARG BASE_IMAGE_TAG="25.01-py3"
 
-FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
+FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynemo
 
 # TODO: non root user by default

@@ -34,7 +34,7 @@ RUN rustup toolchain install 1.85.0-x86_64-unknown-linux-gnu
 # Install OpenAI-compatible frontend and its dependencies from triton server
 # repository. These are used to have a consistent interface, schema, and FastAPI
-# app between Triton Core and Triton Distributed implementations.
+# app between Triton Core and Dynemo implementations.
 ARG OPENAI_SERVER_TAG="r25.01"
 RUN mkdir -p /opt/tritonserver/python && \
     cd /opt/tritonserver/python && \

@@ -78,7 +78,7 @@ ARG TENSORRTLLM_SKIP_CLONE=
 ENV FRAMEWORK=${FRAMEWORK}
 RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,target=/tmp/requirements.txt \
     --mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
-    if [[ "$FRAMEWORK" == "TENSORRTLLM" ]]; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ]; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/triton/llm_binding; fi; fi
+    if [[ "$FRAMEWORK" == "TENSORRTLLM" ]]; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ]; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/dynemo/llm_binding; fi; fi
 
 RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \

@@ -106,7 +106,7 @@ ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
 ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
 ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
 ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
-ENV VLLM_KV_CAPI_PATH="/opt/triton/llm_binding/lib/libtriton_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/llm_binding/lib/libdynemo_llm_capi.so"
 ENV PYTHONUNBUFFERED=1
 
 # Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability

@@ -159,27 +159,27 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c/ && \
     cargo build --release --locked && cargo doc --no-deps
 
-# Install uv, create virtualenv for general use, and build triton_distributed wheel
+# Install uv, create virtualenv for general use, and build dynemo wheel
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12 && \
-    source /opt/triton/venv/bin/activate && \
-    uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+RUN mkdir /opt/dynemo && \
+    uv venv /opt/dynemo/venv --python 3.12 && \
+    source /opt/dynemo/venv/bin/activate && \
+    uv build --wheel --out-dir /workspace/dist && \
+    uv pip install /workspace/dist/dynemo*cp312*.whl
 
 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.
 
-# Install triton_distributed_runtime and triton_distributed_llm wheels globally in container for tests that
+# Install dynemo.runtime and dynemo.llm wheels globally in container for tests that
 # currently run without virtual environment activated.
 # TODO: In future, we may use a virtualenv for everything and remove this.
-RUN pip install /opt/triton/bindings/wheels/triton_distributed*cp312*.whl
+RUN pip install /opt/dynemo/bindings/wheels/dynemo*cp312*.whl
 
 # Copy everything in after install steps to avoid re-running build/install
 # commands on unrelated changes in other dirs.
 COPY . /workspace
container/Dockerfile.vllm

@@ -24,17 +24,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
 # Install uv and create virtualenv
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12
+RUN mkdir /opt/dynemo && \
+    uv venv /opt/dynemo/venv --python 3.12
 
 # Activate virtual environment
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 
 # Install patched vllm - keep this early in Dockerfile to avoid
 # rebuilds from unrelated source code changes
 ARG VLLM_REF="v0.7.2"
-ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
+ARG VLLM_PATCH="vllm_${VLLM_REF}-dynemo-kv-disagg-patch.patch"
 RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
     bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm

@@ -100,25 +100,25 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c && \
     cargo build --release --locked && cargo doc --no-deps
 
-# Build triton_distributed wheel
-RUN source /opt/triton/venv/bin/activate && \
-    uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+# Build dynemo wheel
+RUN source /opt/dynemo/venv/bin/activate && \
+    uv build --wheel --out-dir /workspace/dist && \
+    uv pip install /workspace/dist/dynemo*cp312*.whl
 
 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.
 
-# Tell vllm to use the Triton LLM C API for KV Cache Routing
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
 
 # FIXME: Copy more specific folders in for dev/debug after directory restructure
 COPY . /workspace
 
-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []

@@ -136,10 +136,10 @@ RUN apt update -y && \
     echo "set -g mouse on" >> /root/.tmux.conf
 
 # Set environment variables
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
 
 # Copy binaries
 COPY --from=dev /usr/local/bin/http /usr/local/bin/http

@@ -166,7 +166,7 @@ COPY examples/python_rs/llm/vllm /workspace/examples/python_rs/llm/vllm
 WORKDIR /workspace
 
-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
container/Dockerfile.vllm_nixl
View file @
1af7433b
...
@@ -150,17 +150,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
...
@@ -150,17 +150,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
# Install uv and create virtualenv
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/
triton
&& \
RUN mkdir /opt/
dynemo
&& \
uv venv /opt/
triton
/venv --python 3.12
uv venv /opt/
dynemo
/venv --python 3.12
# Activate virtual environment
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/
triton
/venv
ENV VIRTUAL_ENV=/opt/
dynemo
/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-
triton
-kv-disagg-patch.patch"
ARG VLLM_PATCH="vllm_${VLLM_REF}-
dynemo
-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
...
@@ -225,25 +225,25 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c && \
     cargo build --release --locked && cargo doc --no-deps
-# Build triton_distributed wheel
-RUN source /opt/triton/venv/bin/activate && \
+# Build dynemo wheel
+RUN source /opt/dynemo/venv/bin/activate && \
     uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+    uv pip install /workspace/dist/dynemo*cp312*.whl
 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.
-# Tell vllm to use the Triton LLM C API for KV Cache Routing
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
 # FIXME: Copy more specific folders in for dev/debug after directory restructure
 COPY . /workspace
-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
...
@@ -261,10 +261,10 @@ RUN apt update -y && \
     echo "set -g mouse on" >> /root/.tmux.conf
 # Set environment variables
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
 # Copy binaries
 COPY --from=dev /usr/local/bin/http /usr/local/bin/http
...
@@ -291,7 +291,7 @@ COPY examples/python_rs/llm/vllm_nixl /workspace/examples/python_rs/llm/vllm_nix
 WORKDIR /workspace
-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
container/deps/clone_tensorrtllm.sh
...
@@ -16,7 +16,7 @@
 TENSORRTLLM_BACKEND_REPO_TAG=
 TENSORRTLLM_BACKEND_REBUILD=
-TRITON_LLM_PATH=
+DYNEMO_LLM_PATH=
 GIT_TOKEN=
 GIT_REPO=
...
@@ -43,9 +43,9 @@ get_options() {
                 missing_requirement $1
             fi
             ;;
-        --triton-llm-path)
+        --dynemo-llm-path)
             if [ "$2" ]; then
-                TRITON_LLM_PATH=$2
+                DYNEMO_LLM_PATH=$2
                 shift
             else
                 missing_requirement $1
...
@@ -147,9 +147,9 @@ if [ ! -z ${TENSORRTLLM_BACKEND_REBUILD} ]; then
     # Build the backend
     (cd inflight_batcher_llm/src \
-        && cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DTRITON_LLM_PATH=$TRITON_LLM_PATH .. \
+        && cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DDYNEMO_LLM_PATH=$DYNEMO_LLM_PATH .. \
         && make install \
-        && cp libtriton_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
+        && cp libdynemo_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
         && cp trtllmExecutorWorker /opt/tritonserver/backends/tensorrtllm/ \
     )
 fi
...
container/deps/vllm/vllm_v0.7.2-triton-kv-disagg-patch.patch → container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch
...
@@ -31,7 +31,7 @@ index 9ba49757..a2f88854 100644
                              f"and `kv_both`")
 -        if self.kv_connector is not None and self.kv_role is None:
-+        if self.kv_connector is not None and self.kv_connector != "TritonNixlConnector" and self.kv_role is None:
++        if self.kv_connector is not None and self.kv_connector != "DynemoNixlConnector" and self.kv_role is None:
             raise ValueError("Please specify kv_disagg_role when kv_connector "
                              "is set, supported roles are `kv_producer`, "
                              "`kv_consumer`, and `kv_both`")
...
@@ -44,7 +44,7 @@ index 9ba49757..a2f88854 100644
     def need_kv_parallel_group(self) -> bool:
         # for those database-based connector, vLLM does not need to create
         # parallel group, and in that case the kv parallel size will be 1.
-+        if self.kv_connector == "TritonNixlConnector":
++        if self.kv_connector == "DynemoNixlConnector":
 +            return False
         return self.kv_connector is not None and self.kv_parallel_size > 1
...
@@ -277,7 +277,7 @@ index 00000000..350453cd
 +logger = logging.getLogger(__name__)
 +
 +
-+class TritonResult:
++class DynemoResult:
 +    OK = 0
 +    ERR = 1
 +
...
@@ -290,12 +290,12 @@ index 00000000..350453cd
 +
 +        try:
 +            self.lib = ctypes.CDLL(lib_path)
-+            self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
-+            self.lib.triton_llm_init.restype = c_uint32
++            self.lib.dynemo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
++            self.lib.dynemo_llm_init.restype = c_uint32
 +
-+            result = self.lib.triton_llm_init(namespace.encode(),
++            result = self.lib.dynemo_llm_init(namespace.encode(),
 +                                              component.encode(), worker_id)
-+            if result == TritonResult.OK:
++            if result == DynemoResult.OK:
 +                logger.info(
 +                    "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
 +                )
...
@@ -306,7 +306,7 @@ index 00000000..350453cd
 +            print(f"Failed to load {lib_path}")
 +            raise e
 +
-+        self.lib.triton_kv_event_publish_stored.argtypes = [
++        self.lib.dynemo_kv_event_publish_stored.argtypes = [
 +            ctypes.c_uint64,                  # event_id
 +            ctypes.POINTER(ctypes.c_uint32),  # token_ids
 +            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
...
@@ -315,14 +315,14 @@ index 00000000..350453cd
 +            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
 +            ctypes.c_uint64,                  # lora_id
 +        ]
-+        self.lib.triton_kv_event_publish_stored.restype = ctypes.c_uint32  # triton_llm_result_t
++        self.lib.dynemo_kv_event_publish_stored.restype = ctypes.c_uint32  # dynemo_llm_result_t
 +
-+        self.lib.triton_kv_event_publish_removed.argtypes = [
++        self.lib.dynemo_kv_event_publish_removed.argtypes = [
 +            ctypes.c_uint64,                  # event_id
 +            ctypes.POINTER(ctypes.c_uint64),  # block_ids
 +            ctypes.c_size_t,                  # num_blocks
 +        ]
-+        self.lib.triton_kv_event_publish_removed.restype = ctypes.c_uint32  # triton_llm_result_t
++        self.lib.dynemo_kv_event_publish_removed.restype = ctypes.c_uint32  # dynemo_llm_result_t
 +
 +        self.event_id_counter = 0
 +
...
@@ -336,7 +336,7 @@ index 00000000..350453cd
 +                      if parent is not None else None)
 +
 +        # Publish the event
-+        result = self.lib.triton_kv_event_publish_stored(
++        result = self.lib.dynemo_kv_event_publish_stored(
 +            self.event_id_counter,  # uint64_t event_id
 +            token_ids_arr,          # const uint32_t *token_ids
 +            num_block_tokens,       # const uintptr_t *num_block_tokens
...
@@ -346,7 +346,7 @@ index 00000000..350453cd
 +            0,  # uint64_t lora_id
 +        )
 +
-+        if result == TritonResult.OK:
++        if result == DynemoResult.OK:
 +            logger.debug(f"Store - Published KV Event: {block.content_hash}")
 +        else:
 +            logger.debug(
...
@@ -355,28 +355,23 @@ index 00000000..350453cd
 +        self.event_id_counter += 1
 +
 +    def enqueue_removed_event(self, block_hash: PrefixHash):
-+        result = self.lib.triton_kv_event_publish_removed(
++        result = self.lib.dynemo_kv_event_publish_removed(
 +            self.event_id_counter,
 +            (ctypes.c_uint64 * 1)(block_hash),
 +            1,
 +        )
 +
-+        if result == TritonResult.OK:
++        if result == DynemoResult.OK:
 +            logger.debug(f"Remove - Published KV Event: {block_hash}")
 +        else:
 +            logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
 +
 +        self.event_id_counter += 1
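The KVCacheEventManager code above follows the standard `ctypes` idiom: load a shared library, declare `argtypes`/`restype`, then call and check a status code. A self-contained sketch of the same idiom against the C math library (libm stands in here for the dynemo C API, purely for illustration):

```python
import ctypes
import ctypes.util

# Load a shared library the same way the patch loads the KV-routing
# C API pointed to by VLLM_KV_CAPI_PATH.
libm = ctypes.CDLL(ctypes.util.find_library("m"))

# Without argtypes/restype, ctypes assumes C int everywhere, which
# silently corrupts double arguments and 64-bit handles.
libm.pow.argtypes = [ctypes.c_double, ctypes.c_double]
libm.pow.restype = ctypes.c_double

result = libm.pow(2.0, 10.0)
print(result)  # 1024.0
```

Declaring the signatures up front, as the patch does for `dynemo_llm_init` and the publish functions, is what makes the calls type-safe across the FFI boundary.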
 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
-index f507847a..abe574d1 100644
+index f507847a..ee20d50c 100644
 --- a/vllm/core/scheduler.py
 +++ b/vllm/core/scheduler.py
-@@ -4,22 +4,22 @@ import enum
+@@ -8,18 +8,17 @@ from collections import deque
 import os
 import random
 import time
 +import copy
 from collections import deque
 from dataclasses import dataclass, field
 from typing import Callable, Deque, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
...
@@ -398,7 +393,7 @@ index f507847a..abe574d1 100644
 logger = init_logger(__name__)
 # Test-only. If configured, decode is preempted with
-@@ -325,12 +325,14 @@ class Scheduler:
+@@ -325,12 +324,14 @@ class Scheduler:
     def __init__(
         self,
...
@@ -413,7 +408,7 @@ index f507847a..abe574d1 100644
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         # Note for LoRA scheduling: the current policy is extremely
-@@ -356,6 +358,7 @@ class Scheduler:
+@@ -356,6 +357,7 @@ class Scheduler:
         # Create the block space manager.
         self.block_manager = BlockSpaceManagerImpl(
...
@@ -421,7 +416,7 @@ index f507847a..abe574d1 100644
             block_size=self.cache_config.block_size,
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
-@@ -371,6 +374,16 @@ class Scheduler:
+@@ -371,6 +373,14 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         # Contain decode requests that are swapped out.
         self.swapped: Deque[SequenceGroup] = deque()
...
@@ -429,8 +424,6 @@ index f507847a..abe574d1 100644
 +        # Sequence groups in the REMOTE_PREFILLING state.
 +        # Contain requests that are being prefilled by a remote worker.
 +        self.remote_prefilling: Deque[SequenceGroup] = deque()
-+        # Contain requests that are being prefilled by a local worker.
-+        self.prefill_sending: Deque[SequenceGroup] = deque()
 +
 +        self._remote_prefill_outputs: Dict[str, int] = {}
 +
@@ -438,25 +431,24 @@ index f507847a..abe574d1 100644
...
@@ -438,25 +431,24 @@ index f507847a..abe574d1 100644
# Sequence groups finished requests ids since last step iteration.
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
# can and must be released after the current step.
@@ -501,7 +51
4
,7 @@
class Scheduler:
@@ -501,7 +51
1
,7 @@
class Scheduler:
def has_unfinished_seqs(self) -> bool:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
or len(self.prefill_sending) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
return self.block_manager.get_prefix_cache_hit_rate(device)
-@@ -523,6 +536,8 @@ class Scheduler:
+@@ -523,6 +533,7 @@ class Scheduler:
         budget: SchedulingBudget,
         curr_loras: Optional[Set[int]],
         enable_chunking: bool = False,
-+        finished_prefills: Optional[Set[str]] = None,
-+        finished_transfers: Optional[Set[str]] = None
++        finished_prefills: Optional[Set[str]] = None
     ) -> SchedulerRunningOutputs:
         """Schedule sequence groups that are running.
-@@ -537,6 +552,8 @@ class Scheduler:
+@@ -537,6 +548,8 @@ class Scheduler:
             chunked number of tokens are scheduled if
             `budget.num_batched_tokens` has not enough capacity to schedule
             all tokens.
...
@@ -465,7 +457,7 @@ index f507847a..abe574d1 100644
         Returns:
             SchedulerRunningOutputs.
-@@ -566,6 +583,38 @@ class Scheduler:
+@@ -566,6 +579,24 @@ class Scheduler:
         preempted: List[SequenceGroup] = ret.preempted
         swapped_out: List[SequenceGroup] = ret.swapped_out
...
@@ -476,7 +468,6 @@ index f507847a..abe574d1 100644
 +                if seq_group.request_id not in finished_prefills:
 +                    leftover_remote_prefilling_sequences.append(seq_group)
 +                    continue
-+
 +                else:
 +                    finished_prefills.remove(seq_group.request_id)
 +                    assert len(seq_group.seqs) == 1
...
@@ -487,63 +478,39 @@ index f507847a..abe574d1 100644
 +                    seq.data._stage = SequenceStage.DECODE
 +                    self.running.appendleft(seq_group)
 +            remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
-+
-+            remote_transfers_queue = self.prefill_sending
-+            leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
-+            while remote_transfers_queue:
-+                seq_group = remote_transfers_queue.popleft()
-+                if seq_group.request_id not in finished_transfers:
-+                    leftover_remote_transfers_sequences.append(seq_group)
-+                else:
-+                    finished_transfers.remove(seq_group.request_id)
-+                    assert len(seq_group.seqs) == 1
-+                    seq = seq_group.seqs[0]
-+                    self.free_seq(seq)
-+            remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
 +
         running_queue = self.running
         assert len(self._async_stopped) == 0
         while running_queue:
-@@ -1008,7 +1057,17 @@ class Scheduler:
+@@ -1008,7 +1039,7 @@ class Scheduler:
             if curr_loras is not None and lora_int_id > 0:
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
 -            self._allocate_and_set_running(seq_group)
-+
-+            seq_group_copy = copy.deepcopy(seq_group)
-+            seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
-+
-+            logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
-+            logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
 +            self._allocate_and_set_running_or_remote_prefill(seq_group)
-+            if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
-+                logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
-+                self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
-+                self.prefill_sending.append(seq_group_copy)
             if enable_chunking and self.scheduler_config.is_multi_step:
                 blocks_to_copy: List[Tuple[int, int]] = []
-@@ -1048,7 +1107,7 @@ class Scheduler:
+@@ -1048,7 +1079,7 @@ class Scheduler:
             num_lookahead_slots=self._get_num_lookahead_slots(
                 is_prefill=True, enable_chunking=enable_chunking))
 -    def _schedule_default(self) -> SchedulerOutputs:
-+    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
         """Schedule queued requests.
         The current policy is designed to optimize the throughput. First,
-@@ -1090,7 +1149,9 @@ class Scheduler:
+@@ -1090,7 +1121,8 @@ class Scheduler:
         if len(prefills.seq_groups) == 0:
             running_scheduled = self._schedule_running(budget,
                                                        curr_loras,
 -                                                      enable_chunking=False)
 +                                                      enable_chunking=False,
-+                                                      finished_prefills=finished_prefills,
-+                                                      finished_transfers=finished_transfers)
++                                                      finished_prefills=finished_prefills)
         # If any sequence group is preempted, do not swap in any sequence
         # group. because it means there's no slot for new running requests.
-@@ -1106,7 +1167,12 @@ class Scheduler:
+@@ -1106,7 +1138,12 @@ class Scheduler:
         self.waiting.extendleft(running_scheduled.preempted)
         # Update new running requests.
         if len(prefills.seq_groups) > 0:
...
@@ -557,31 +524,30 @@ index f507847a..abe574d1 100644
         self.running.extend(running_scheduled.decode_seq_groups_list)
-@@ -1248,12 +1314,14 @@ class Scheduler:
+@@ -1248,12 +1285,14 @@ class Scheduler:
             len(running_scheduled.swapped_out)),
         )
 -    def _schedule(self) -> SchedulerOutputs:
-+    def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
         """Schedule queued requests."""
         if self.scheduler_config.chunked_prefill_enabled:
-+            if finished_prefills or finished_transfers:
++            if finished_prefills:
 +                raise ValueError("Chunked prefill does not support remote prefills")
             return self._schedule_chunked_prefill()
         else:
 -            return self._schedule_default()
-+            return self._schedule_default(finished_prefills, finished_transfers)
++            return self._schedule_default(finished_prefills)
     def _can_append_slots(self, seq_group: SequenceGroup,
                           enable_chunking: bool) -> bool:
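The `_schedule`/`_schedule_default` changes above thread an optional set of finished remote prefills through the scheduling entry point and reject it when chunked prefill is enabled. A toy sketch of that control flow (not vLLM code; the class and its behavior are illustrative only):

```python
from typing import List, Optional, Set

class MiniScheduler:
    """Illustrates threading an optional finished-prefills set through
    a scheduling call chain, as the patch does for vLLM's Scheduler."""

    def __init__(self, chunked_prefill_enabled: bool):
        self.chunked_prefill_enabled = chunked_prefill_enabled

    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> List[str]:
        # Stand-in for the real scheduling work: just report which
        # remote prefills were handed down to us.
        return sorted(finished_prefills or ())

    def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> List[str]:
        if self.chunked_prefill_enabled:
            # Mirrors the patch: remote prefills are incompatible
            # with chunked prefill, so reject them up front.
            if finished_prefills:
                raise ValueError("Chunked prefill does not support remote prefills")
            return []
        return self._schedule_default(finished_prefills)

print(MiniScheduler(False)._schedule({"req-1"}))  # ['req-1']
```

The design point is that the optional parameter defaults to `None`, so existing callers of `_schedule()` keep working unchanged.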
-@@ -1287,14 +1355,16 @@ class Scheduler:
+@@ -1287,14 +1326,15 @@ class Scheduler:
         return no_single_seq
     def schedule(
 -            self
 +            self,
-+            finished_prefills: Optional[Set[str]] = None,
-+            finished_transfers: Optional[Set[str]] = None
++            finished_prefills: Optional[Set[str]] = None
     ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
...
@@ -590,11 +556,11 @@ index f507847a..abe574d1 100644
 -        scheduler_outputs: SchedulerOutputs = self._schedule()
 +        scheduler_start_time = time.perf_counter()
-+        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
++        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
         now = time.time()
         if not self.cache_config.enable_prefix_caching:
-@@ -1333,7 +1403,8 @@ class Scheduler:
+@@ -1333,7 +1373,8 @@ class Scheduler:
                 encoder_seq_data = None
                 cross_block_table = None
...
@@ -604,24 +570,18 @@ index f507847a..abe574d1 100644
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
-@@ -1364,9 +1435,16 @@ class Scheduler:
+@@ -1364,6 +1405,10 @@ class Scheduler:
                     < seqs[0].data.get_len()):
                 do_sample = False
 +            is_remote_prefill = False
 +            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
 +                is_remote_prefill = True
-+            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
-+                block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
 +
             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
             if is_first_prefill or not self.scheduler_config.send_delta_data:
-+                logger.debug("Assinged blocks: %s", block_tables)
@@ -1392,6 +1437,7 @@ class Scheduler:
                 seq_group_metadata = SequenceGroupMetadata(
                     request_id=seq_group.request_id,
                     is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler:
             if scheduler_outputs.num_prefill_groups > 0 else None,
             mm_processor_kwargs=seq_group.mm_processor_kwargs,
             prompt_adapter_request=seq_group.prompt_adapter_request,
...
@@ -629,7 +589,7 @@ index f507847a..abe574d1 100644
                 )
             else:
                 # When SPMD mode is enabled, we only send delta data except for
-@@ -1490,10 +1569,13 @@ class Scheduler:
+@@ -1490,10 +1536,13 @@ class Scheduler:
         self._async_stopped.clear()
...
@@ -645,80 +605,12 @@ index f507847a..abe574d1 100644
     def _append_slots(self,
                       seq_group: SequenceGroup,
-diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
-new file mode 100644
-index 00000000..9b938039
---- /dev/null
-+++ b/vllm/distributed/device_communicators/kv_rearrange.py
-@@ -0,0 +1,61 @@
-+import torch
-+import triton
-+import triton.language as tl
-+
-+@triton.jit
-+def rearrange_kernel(
-+    t1_ptr,
-+    t2_ptr,
-+    N,
-+    B,
-+    H,
-+    C,
-+    d,
-+    tensor_subset_size,
-+    block_size,
-+    token_size,
-+    BLOCK_SIZE: tl.constexpr,
-+):
-+    pid = tl.program_id(0)
-+
-+    block_start = pid * BLOCK_SIZE
-+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-+
-+    curr_n = offsets // block_size
-+    curr_b = offsets // token_size % B
-+    curr_h = offsets // C % H
-+    curr_c = offsets % C
-+
-+    src_pos = offsets
-+
-+    tp_group = curr_h * d // H
-+    dst_h = curr_h % (H // d)
-+    tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
-+
-+    dst_pos = tensor_subset_size * tp_group + tp_group_offset
-+
-+    tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
-+
-+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
-+    N, B, H, C = t1.shape
-+
-+    assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
-+    assert H % d == 0, "H must be divisible by d"
-+
-+    block_size = B * H * C
-+    token_size = H * C
-+    tensor_size = N * block_size
-+    tensor_subset_size = tensor_size // d
-+
-+    BLOCK_SIZE = 1024
-+    grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
-+
-+    rearrange_kernel[grid](
-+        t1, t2,
-+        N, B, H, C,
-+        d,
-+        tensor_subset_size,
-+        block_size,
-+        token_size,
-+        BLOCK_SIZE=BLOCK_SIZE
-+    )
-\ No newline at end of file
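The removed `rearrange_kernel` regroups a `(N, B, H, C)` KV cache so that each of `d` tensor-parallel groups owns a contiguous slice of `H/d` heads. A NumPy reference of the same index mapping (this equivalence is an assumption derived from the kernel's offset arithmetic, not code from the commit):

```python
import numpy as np

def rearrange_reference(t1: np.ndarray, d: int) -> np.ndarray:
    # Split the head axis into d tensor-parallel groups, then move the
    # group axis outermost so each group's data is one contiguous chunk,
    # matching dst_pos = tensor_subset_size * tp_group + tp_group_offset.
    N, B, H, C = t1.shape
    assert H % d == 0, "H must be divisible by d"
    return (t1.reshape(N, B, d, H // d, C)
              .transpose(2, 0, 1, 3, 4)
              .reshape(N, B, H, C))

t1 = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
t2 = rearrange_reference(t1, d=2)
# Element (n=0, b=0, h=2, c=0) lands in the second TP group's chunk,
# since tp_group = h * d // H = 2 * 2 // 4 = 1.
```

Writing it as `reshape` + `transpose` makes the layout transform auditable; the Triton kernel computes the same permutation element-wise on the GPU.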
 diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
 new file mode 100644
-index 00000000..f1618bc4
+index 00000000..bc962726
 --- /dev/null
 +++ b/vllm/distributed/device_communicators/nixl.py
-@@ -0,0 +1,318 @@
+@@ -0,0 +1,249 @@
 +import torch
 +from typing import List, Tuple
 +from vllm.config import VllmConfig
...
@@ -726,18 +618,39 @@ index 00000000..f1618bc4
 +import msgspec
 +import time
 +import uuid
-+from collections import defaultdict
-+from .kv_rearrange import rearrange_tensors
++from nixl_wrapper import nixl_wrapper as NixlWrapper
 +
 +logger = init_logger(__name__)
 +
-+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
-+try:
-+    from nixl_wrapper import nixl_wrapper as NixlWrapper  # type: ignore
-+    logger.info("NIXL is available")
-+except ImportError:
-+    logger.warning("NIXL is not available")
-+    NixlWrapper = None
++def nixl_wrapper_init_patch(self, agent_name, nixl_config):
++    logger.info("Initializing patched NixlWrapper")
++    import nixl_bindings as nixl
++
++    # Read available backends and device info from nixl_config
++    # For now setting the multithreading to enabled.
++    devices = nixl.nixlAgentConfig(False)
++    init = nixl.nixlUcxInitParams()
++
++    self.name = agent_name
++    self.notifs = {}
++    self.backends = {}
++    self.agent = nixl.nixlAgent(agent_name, devices)
++    self.backends["UCX"] = self.agent.createBackend(init)
++
++    self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
++                      "VRAM": nixl.VRAM_SEG,
++                      "cpu": nixl.DRAM_SEG,
++                      "cuda": nixl.VRAM_SEG}
++    self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
++                     "READ": nixl.NIXL_RD_FLUSH,
++                     "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
++                     "READ_NOTIF": nixl.NIXL_RD_NOTIF}
++
++    print("Initializied NIXL agent:", agent_name)
++
++NixlWrapper.__init__ = nixl_wrapper_init_patch
 +
 +
 +class NixlMetadata(
 +    msgspec.Struct,
...
@@ -749,20 +662,14 @@ index 00000000..f1618bc4
 +    kv_caches_base_addr: List[List[Tuple[int, int]]]  # base address for each rank for each layer for keys and values
 +
 +
-+class TritonNixlConnector:
++class DynemoNixlConnector:
 +    def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
 +        self.vllm_config = vllm_config
-+        if NixlWrapper is None:
-+            logger.error("NIXL is not available")
-+            raise RuntimeError("NIXL is not available")
-+        logger.info("Initializing NIXL wrapper")
 +        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
 +
 +        self.num_layers = None
 +        self.num_blocks = None
-+        self.num_heads = None
 +        self.block_len = None
-+        self.kv_caches = None
 +        self.kv_caches_base_addr = {}
 +        self.kv_cache_shape = {}
 +
...
@@ -771,51 +678,33 @@ index 00000000..f1618bc4
...
@@ -771,51 +678,33 @@ index 00000000..f1618bc4
+ self.engine_id = engine_id
+ self.engine_id = engine_id
+ self.rank = rank
+ self.rank = rank
+ self.notifs = {}
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+
+ @property
+ @property
+ def agent_name(self):
+ def agent_name(self):
+ return self.nixl_wrapper.name
+ return self.nixl_wrapper.name
+
+
+    def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+        _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+        caches_data = []
+        self.num_layers = len(kv_caches)
+        _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+        self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+        logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+        self.num_layers = len(kv_caches)
+
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads
+        self.kv_caches = kv_caches
+        kv_caches_base_addr = []
+        caches_data = []
+        blocks_data = []
+        for key_cache, value_cache in kv_caches:
+            for cache in [key_cache, value_cache]:
+                base_addr = cache.data_ptr()
+                region_len = num_blocks * self.block_len
+                region_len = cache.numel() * cache.element_size()
+                caches_data.append((base_addr, region_len, self.rank))
+                gpu_id = cache.get_device()
+                for block_id in range(self.num_blocks):
+                    assert gpu_id > -1, "Tensor is not on GPU"
+                    blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+                caches_data.append((base_addr, region_len, gpu_id))
+
+            kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+        descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+        logger.debug("Registering descs: %s", caches_data)
+        self.nixl_wrapper.register_memory(descs)
+        self._registered_descs.append(descs)
+
+        self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+        self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+    def get_agent_metadata(self):
+        return self.nixl_wrapper.get_agent_metadata()
+
...
@@ -825,14 +714,10 @@ index 00000000..f1618bc4
+        for agent_name in self._remote_agents.values():
+            self.nixl_wrapper.remove_remote_agent(agent_name)
+
+    def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+    def add_remote_agent(self, engine_id, agent_metadata):
+        self._tp_size[engine_id] = agent_tp
+        agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+        agent_names = []
+        self._remote_agents[engine_id] = agent_name
+        for agent_meta in agent_metadata:
+        return agent_name
+            agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+            agent_names.append(agent_name)
+        self._remote_agents[engine_id] = agent_names
+        return agent_names
+
+    def get_descs_ids(self, layer_ids, block_ids):
+        if layer_ids == "all":
...
@@ -847,29 +732,17 @@ index 00000000..f1618bc4
+                    descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+        return descs_ids
+
+    def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+    def _get_range_descs(self, engine_id, ranges, layer_ids):
+        if rank is None:
+            rank = self.rank
+            offset_block_len = self.block_len
+            block_len = self.block_len // tp_multiplier
+            tp_offset = i * block_len
+        else:
+            offset_block_len = self.block_len // tp_multiplier
+            block_len = self.block_len // tp_multiplier
+            tp_offset = 0
+        logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+        if layer_ids == "all":
+            layer_ids = list(range(self.num_layers))
+        blocks_data = []
+        for layer_id in layer_ids:
+            for range_start, range_end in ranges:
+                range_len = range_end - range_start + 1
+                key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+                key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+                start_offset = range_start * self.block_len
+                start_offset = range_start * offset_block_len + tp_offset * range_len
+                blocks_len = (range_end - range_start + 1) * self.block_len
+                blocks_len = range_len * block_len
+                blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+                blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+                blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+                blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+        logger.debug("Blocks data: %s", blocks_data)
+        return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+    def _get_ranges(self, block_ids):
...
@@ -882,9 +755,9 @@ index 00000000..f1618bc4
+        ranges = []
+        for i in range(len(sorted_block_ids)):
+            if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+                ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+                ranges.append([sorted_block_ids[i]])
+            else:
+                ranges[-1][1] = sorted_block_ids[i]
+                ranges[-1].append(sorted_block_ids[i])
+        return ranges
+
+    def _get_same_length_ranges(self, src_ranges, dst_ranges):
...
@@ -927,21 +800,8 @@ index 00000000..f1618bc4
+
+
+
+    def _get_block_descs_ids(self, layer_ids, block_ids):
+    def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+        if layer_ids == "all":
+            layer_ids = list(range(self.num_layers))
+        if block_ids == "all":
+            block_ids = list(range(self.num_blocks))
+        descs_ids = []
+        for layer_id in layer_ids:
+            for is_value in [0, 1]:
+                for block_id in block_ids:
+                    descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+        return descs_ids
+
+
+    def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+        start_time = time.perf_counter()
+        logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
...
@@ -950,61 +810,43 @@ index 00000000..f1618bc4
+        # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+        # one less block due to the missing last token.
+        dst_block_ids = dst_block_ids[:len(src_block_ids)]
+        assert len(staging_block_ids) == len(src_block_ids)
+
+        if use_prepped_xfer:
+            raise NotImplementedError("Prepped xfer is not implemented")
+            # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+            # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+            # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+            # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+            # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+            # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+            #                                              dst_xfer_side_handle, dst_block_descs_ids,
+            #                                              notify_msg, "WRITE", no_check=True)
+        # else:
+        # Legacy path using range-based transfers
+        src_ranges = self._get_ranges(src_block_ids)
+        staging_ranges = self._get_ranges(staging_block_ids)
+        dst_ranges = self._get_ranges(dst_block_ids)
+        src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+        assert len(src_ranges) == 1
+        logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+        assert len(staging_ranges) == 1
+
+        tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+        logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+
+        src_range_start, src_range_end = src_ranges[0]
+        src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+        src_range_len = src_range_end - src_range_start + 1
+        dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+        staging_range_start, staging_range_end = staging_ranges[0]
+        staging_range_len = staging_range_end - staging_range_start + 1
+
+        logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+        for kv_cache in self.kv_caches:
+            for cache in kv_cache:
+                rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
+        staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+        assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+
+        for i in range(tp_multiplier):
+
+            src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+            dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+            logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+            logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+        handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+            handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+                                                       self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+                                                       notify_msg, "WRITE")
+            self._transfers[notify_msg].append(handle)
+        logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+        logger.debug("Transfer handle: %s", handle)
+        status = self.nixl_wrapper.transfer(handle)
+        logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+        logger.debug("Transfer status: %s", status)
+        # TODO ptarasiewicz: remove blocking transfer mem
+        # add scheduler check for transfer done
+        while True:
+            xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+            if xfer_state == "ERR":
+                raise RuntimeError("Transfer failed")
+            elif xfer_state == "DONE":
+                logger.debug("Transfer done")
+                break
+            elif xfer_state == "PROC":
+                time.sleep(0.01)
+            else:
+                raise RuntimeError("Unknown transfer state")
+        logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+        self.nixl_wrapper.abort_xfer(handle)
+        logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+        logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+
+    def deserialize_descs(self, serialized_descs):
+        return self.nixl_wrapper.deserialize_descs(serialized_descs)
...
@@ -1018,26 +860,6 @@ index 00000000..f1618bc4
+
+    def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+        self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+    def get_done_tranfers(self) -> List[str]:
+        done_req_ids = []
+        for req_id, handles in self._transfers.items():
+            running_reqs = []
+            for handle in handles:
+                xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+                if xfer_state == "DONE":
+                    # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+                    continue
+                if xfer_state == "PROC":
+                    running_reqs.append(handle)
+                else:
+                    raise RuntimeError("Transfer failed with state %s", xfer_state)
+            if len(running_reqs) == 0:
+                done_req_ids.append(req_id)
+            else:
+                self._transfers[req_id] = running_reqs
+        return done_req_ids
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
...
@@ -1064,9 +886,9 @@ index fe480533..61a357d0 100644
     "SimpleConnector")
+
+KVConnectorFactory.register_connector(
+    "TritonNcclConnector",
+    "DynemoNcclConnector",
+    "vllm.distributed.kv_transfer.kv_connector.triton_connector",
+    "vllm.distributed.kv_transfer.kv_connector.dynemo_connector",
+    "TritonConnector")
+    "DynemoConnector")
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
...
@@ -1396,7 +1218,7 @@ index 2033e976..e33919c1 100644
+            world_group.broadcast_object(kv_config_enhanced)
+
+        else:
+            raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch")
+            raise NotImplementedError("MooncakeConnector is not supported in Dynemo Distributed vllm patch")
+    else:
+        kv_config_enhanced = world_group.broadcast_object()
+        logger.info("kv_config_enhanced: %s", kv_config_enhanced)
...
@@ -1407,11 +1229,11 @@ index 2033e976..e33919c1 100644
+        self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+        self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
...
@@ -1443,7 +1265,7 @@ index 00000000..cb3b3660
+logger = init_logger(__name__)
+
+
+class TritonConnector(KVConnectorBase):
+class DynemoConnector(KVConnectorBase):
+
+    def __init__(
+        self,
...
@@ -1457,16 +1279,16 @@ index 00000000..cb3b3660
+        self.tp_size = config.parallel_config.tensor_parallel_size
+        self.rank = rank
+
+        if self.config.kv_connector != "TritonNcclConnector":
+        if self.config.kv_connector != "DynemoNcclConnector":
+            raise NotImplementedError("Only TritonNcclConnector is supported by the TritonConnector class")
+            raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+        from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+            PyNcclPipe)
+        from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
+        from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+            TritonNcclDataPlane)
+            DynemoNcclDataPlane)
+
+        logger.info(
+            "Initializing TritonNcclConnector under kv_transfer_config %s",
+            "Initializing DynemoNcclConnector under kv_transfer_config %s",
+            self.config)
+
+        self.lookup_buffer_size = self.config.kv_buffer_size
...
@@ -1498,7 +1320,7 @@ index 00000000..cb3b3660
+            port_offset=port_offset_base,
+        )
+
+        self.data_plane = TritonNcclDataPlane(
+        self.data_plane = DynemoNcclDataPlane(
+            data_pipe=self.data_pipe,
+            port=self._get_data_plane_port(self.global_kv_rank),
+        )
...
@@ -2233,11 +2055,11 @@ index 7aa53d07..f5dd50b7 100644
     def close(self):
         """
diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
...
@@ -2253,7 +2075,7 @@ index 00000000..8a356504
+logger = logging.getLogger(__name__)
+
+
+class TritonNcclDataPlane:
+class DynemoNcclDataPlane:
+    def __init__(
+        self,
+        data_pipe: PyNcclPipe,
...
@@ -2399,7 +2221,7 @@ index 321902d1..b8937ef8 100644
 def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..254337cb 100644
index d82d9ad9..9ba1a326 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
...
@@ -2482,13 +2304,11 @@ index d82d9ad9..254337cb 100644
+        self.engine_id = str(uuid.uuid4())
+        self._nixl_agents_names: Optional[List[str]] = None
+        if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+        if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+            self._nixl_agents_names = self._initialize_nixl()
+
+        self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+        self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+        self._finished_prefills = set()
+        self._finished_transfers = set()
+
+    @property
+    def is_nixl_initialized(self) -> bool:
...
@@ -2507,6 +2327,8 @@ index d82d9ad9..254337cb 100644
+        engine_id = nixl_metadata.engine_id
+        agents_metadata = nixl_metadata.agent_metadata
+        kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+        if len(agents_metadata) != len(self._nixl_agents_names):
+            raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+        return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+    def _initialize_nixl(self) -> List[bytes]:
...
@@ -2540,16 +2362,7 @@ index d82d9ad9..254337cb 100644
             ParallelSampleSequenceGroup.add_request(
                 request_id,
                 self,
@@ -574,6 +624,8 @@ class LLMEngine:
@@ -584,7 +634,7 @@ class LLMEngine:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
+        if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+            next(self.seq_counter)  # empty sequence for staging
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
         if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +636,7 @@ class LLMEngine:
         encoder_inputs = None
         seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
...
@@ -2558,7 +2371,7 @@ index d82d9ad9..254337cb 100644
         encoder_seq = (None if encoder_inputs is None else Sequence(
             seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +653,12 @@ class LLMEngine:
@@ -601,8 +651,12 @@ class LLMEngine:
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
...
@@ -2572,7 +2385,7 @@ index d82d9ad9..254337cb 100644
             seq_group = self._create_sequence_group_with_pooling(
                 request_id,
                 seq,
@@ -673,6 +729,7 @@ class LLMEngine:
@@ -673,6 +727,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
...
@@ -2580,7 +2393,7 @@ index d82d9ad9..254337cb 100644
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -765,6 +822,7 @@ class LLMEngine:
@@ -765,6 +820,7 @@ class LLMEngine:
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
             priority=priority,
...
@@ -2588,7 +2401,7 @@ index d82d9ad9..254337cb 100644
         )
     def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +857,7 @@ class LLMEngine:
@@ -799,6 +855,7 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         encoder_seq: Optional[Sequence] = None,
         priority: int = 0,
...
@@ -2596,7 +2409,7 @@ index d82d9ad9..254337cb 100644
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +888,9 @@ class LLMEngine:
@@ -829,7 +886,9 @@ class LLMEngine:
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
...
@@ -2607,7 +2420,7 @@ index d82d9ad9..254337cb 100644
         return seq_group
@@ -995,11 +1056,11 @@ class LLMEngine:
@@ -995,11 +1054,11 @@ class LLMEngine:
             # When we process only one request, no pop is required
             # (since later we will process all of the rest)
             (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
...
@@ -2621,7 +2434,7 @@ index d82d9ad9..254337cb 100644
             # Sanity check
             assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1386,49 @@ class LLMEngine:
@@ -1325,15 +1384,49 @@ class LLMEngine:
             # Clear outputs for each new scheduler iteration
             ctx.request_outputs.clear()
...
@@ -2641,7 +2454,7 @@ index d82d9ad9..254337cb 100644
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
-             ) = self.scheduler[virtual_engine].schedule()
+             ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+             ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+
+
+            # Separate remote prefill and running seq groups
...
@@ -2673,7 +2486,7 @@ index d82d9ad9..254337cb 100644
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1478,31 @@ class LLMEngine:
@@ -1383,9 +1476,29 @@ class LLMEngine:
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
...
@@ -2687,11 +2500,9 @@ index d82d9ad9..254337cb 100644
+                req_id = scheduled_seq_group.seq_group.request_id
+                seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+                block_table = seq_group_metadata.block_tables[seq_id]
+                staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+                memory_transfer_req = MemoryTransferRequest(
+                    request_id=req_id,
+                    src_block_ids=block_table,
+                    staging_block_ids=staging_block_ids,
+                    dst_block_ids=remote_prefill_params.decode_block_ids,
+                    dst_engine_id=remote_prefill_params.decode_engine_id,
+                    notify_msg=req_id,
...
@@ -2701,13 +2512,13 @@ index d82d9ad9..254337cb 100644
+
+            execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+            outputs, request_notif_counter = self.model_executor.execute_model(
                 execute_model_req=execute_model_req)
-
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
@@ -1396,7 +1513,26 @@ class LLMEngine:
@@ -1396,7 +1509,20 @@ class LLMEngine:
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
             # No outputs in this case
...
@@ -2718,7 +2529,7 @@ index d82d9ad9..254337cb 100644
+                blocks_to_swap_out=[],
+                blocks_to_copy=[])
+
+            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+            outputs, request_notif_counter = self.model_executor.execute_model(
+                execute_model_req=execute_model_req)
+
+            for req_id, notif_count in request_notif_counter.items():
...
@@ -2726,16 +2537,10 @@ index d82d9ad9..254337cb 100644
+                if self._request_notif_counter[req_id] > -1:
+                    self._finished_prefills.add(req_id)
+                    del self._request_notif_counter[req_id]
+
+            for req_id, done_count in request_done_counter.items():
+                self._request_done_counter[req_id] += done_count
+                if self._request_done_counter[req_id] > -1:
+                    self._finished_transfers.add(req_id)
+                    del self._request_done_counter[req_id]
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -1456,7 +1582,7 @@ class LLMEngine:
         # queued control plane messages, such as add/remove lora adapters.
         logger.debug("Stopping remote worker execution loop.")
         self.model_executor.stop_remote_worker_execution_loop()
...
@@ -2813,7 +2618,7 @@ index 3cf1850e..6b90ece7 100644
+    kv_active_blocks: int
+    kv_total_blocks: int
 diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
 index 85b5f31e..d33d546a 100644
 --- a/vllm/engine/multiprocessing/client.py
 +++ b/vllm/engine/multiprocessing/client.py
 @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...
@@ -2895,7 +2700,7 @@ index 85b5f31e..c501e4c8 100644
+
+    @property
+    def using_nixl_connector(self) -> bool:
+        return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector"
+
     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
...
@@ -3346,10 +3151,10 @@ index 786380c3..56a7cf89 100644
     """The output data of one completion output of a request.
 diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
 new file mode 100644
 index 00000000..03f02006
 --- /dev/null
 +++ b/vllm/remote_prefill.py
 @@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
...
@@ -3387,7 +3192,6 @@ index 00000000..957f55de
+    """
+    request_id: str
+    src_block_ids: List[int]
+    dst_block_ids: List[int]
+    dst_engine_id: str
+    notify_msg: str
...
@@ -3531,7 +3335,7 @@ index 12baecde..cbada27f 100644
         if self.vllm_config.kv_transfer_config is None:
             return False
+
+        if self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+            return False
         prefill_meta = model_input.attn_metadata.prefill_metadata
...
@@ -3541,13 +3345,13 @@ index 12baecde..cbada27f 100644
         if self.vllm_config.kv_transfer_config is None:
             return False
+
+        if self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+            return False
         prefill_meta = model_input.attn_metadata.prefill_metadata
 diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
 index 582aa460..ffb7b403 100644
 --- a/vllm/worker/worker.py
 +++ b/vllm/worker/worker.py
 @@ -2,7 +2,7 @@
...
@@ -3563,7 +3367,7 @@ index 582aa460..1b8515bf 100644
 from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
+
 logger = init_logger(__name__)
...
@@ -3577,7 +3381,7 @@ index 582aa460..1b8515bf 100644
+        # TODO ptarasiewicz nixl can also support DRAM
+        assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
+
+        self.nixl_connector = DynemoNixlConnector(self.vllm_config, engine_id, self.local_rank)  # TODO ptarasiewicz: rank or local_rank?
+        assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
+        self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
+        return self.nixl_connector.agent_name
...
@@ -3588,13 +3392,13 @@ index 582aa460..1b8515bf 100644
+
+    def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+        assert self.nixl_connector is not None, "Nixl connector is not initialized"
+        agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank])  # TODO ptarasiewicz: rank or local_rank?
+        self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+        return agent_name
+
+    def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+        assert self.nixl_connector is not None, "Nixl connector is not initialized"
+        self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg)  # TODO ptarasiewicz: rank or local_rank?
+
+    def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+        assert self.nixl_connector is not None, "Nixl connector is not initialized"
...
@@ -3602,8 +3406,8 @@ index 582aa460..1b8515bf 100644
+
+    def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+        if worker_input.src_block_ids is not None:
+            for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+                self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+    def shutdown_nixl(self) -> None:
+        assert self.nixl_connector is not None, "Nixl connector is not initialized"
...
@@ -3621,12 +3425,11 @@ index 582aa460..1b8515bf 100644
             return WorkerInput(
                 num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
                 blocks_to_copy=blocks_to_copy,
                 virtual_engine=virtual_engine,
                 num_steps=num_steps,
+                src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+                dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+                dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+                notify_msg=[r.notify_msg for r in mem_transfer_reqs],
...
@@ -3634,7 +3437,7 @@ index 582aa460..1b8515bf 100644
     @torch.inference_mode()
 diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
 index 819b81fb..d9c039eb 100644
 --- a/vllm/worker/worker_base.py
 +++ b/vllm/worker/worker_base.py
 @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...
@@ -3649,7 +3452,7 @@ index 819b81fb..ecb68530 100644
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
 logger = init_logger(__name__)
...
@@ -3657,17 +3460,16 @@ index 819b81fb..ecb68530 100644
         from vllm.platforms import current_platform
         self.current_platform = current_platform
+        self.nixl_connector: Optional[DynemoNixlConnector] = None
+
     @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput:
     virtual_engine: int = 0
     num_steps: int = 1
+    src_block_ids: Optional[List[List[int]]] = None
+    dst_block_ids: Optional[List[List[int]]] = None
+    dst_engine_id: Optional[List[str]] = None
+    notify_msg: Optional[List[str]] = None
...
@@ -3675,31 +3477,29 @@ index 819b81fb..ecb68530 100644
     @classmethod
     def from_broadcasted_tensor_dict(
         cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput:
             blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
             virtual_engine=tensor_dict["virtual_engine"],
             num_steps=tensor_dict.pop("num_steps"),
+            src_block_ids=tensor_dict.pop("src_block_ids"),
+            dst_block_ids=tensor_dict.pop("dst_block_ids"),
+            dst_engine_id=tensor_dict.pop("dst_engine_id"),
+            notify_msg=tensor_dict.pop("notify_msg"),
         )

     def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput:
             "blocks_to_copy": self.blocks_to_copy,
             "virtual_engine": self.virtual_engine,
             "num_steps": self.num_steps,
+            "src_block_ids": self.src_block_ids,
+            "dst_block_ids": self.dst_block_ids,
+            "dst_engine_id": self.dst_engine_id,
+            "notify_msg": self.notify_msg,
         }
         return tensor_dict
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             return None
         worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
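The WorkerInput hunks above add the same four fields in three places: the dataclass defaults, `from_broadcasted_tensor_dict`, and `as_broadcastable_tensor_dict`. A minimal sketch of why all three must stay in sync (the class below is a simplified stand-in, not vLLM's actual WorkerInput):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MiniWorkerInput:
    """Simplified stand-in for vLLM's WorkerInput showing the dict round trip."""
    num_steps: int = 1
    src_block_ids: Optional[List[List[int]]] = None
    notify_msg: Optional[List[str]] = None

    def as_broadcastable_tensor_dict(self) -> dict:
        # Every new field added to the dataclass must also be serialized here...
        return {
            "num_steps": self.num_steps,
            "src_block_ids": self.src_block_ids,
            "notify_msg": self.notify_msg,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, d: dict) -> "MiniWorkerInput":
        # ...and popped back out here, or non-driver ranks silently drop it.
        return cls(
            num_steps=d.pop("num_steps"),
            src_block_ids=d.pop("src_block_ids"),
            notify_msg=d.pop("notify_msg"),
        )


wi = MiniWorkerInput(num_steps=2, src_block_ids=[[0, 1]], notify_msg=["done"])
rt = MiniWorkerInput.from_broadcasted_tensor_dict(wi.as_broadcastable_tensor_dict())
```

A field present in the dataclass but missing from either method would survive on the driver rank and be `None` everywhere else, which is exactly the class of bug this three-way edit avoids.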
...
@@ -3721,7 +3521,7 @@ index 819b81fb..ecb68530 100644
     def _get_driver_input_and_broadcast(
         self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         self.execute_worker(worker_input)

         # If there is no input, we don't need to execute the model.
...
@@ -3818,7 +3618,7 @@ index 819b81fb..ecb68530 100644
+            else:
+                for i in range(1, get_tp_group().world_size):
+                    all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+            request_notif_counter = defaultdict(int)
+            for notifs in all_new_notifs:
+                for req_ids in notifs.values():
...
@@ -3827,20 +3627,12 @@ index 819b81fb..ecb68530 100644
+
+            if request_notif_counter:
+                logger.debug("Request notif counter: %s", request_notif_counter)
+
+        else:
+            request_notif_counter = {}
         # output is List[SamplerOutput]
-        return output
+        return output, request_notif_counter
+
+    def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+        pass
...
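The `execute_model` hunks above fold per-rank NIXL notifications into a `defaultdict(int)` before returning them alongside the sampler output. The innermost loop body is collapsed in this view, so the sketch below assumes a plain increment per notified request id:

```python
from collections import defaultdict


def count_notifications(all_new_notifs):
    """Fold per-rank notification dicts ({agent_name: [request_id, ...]})
    into one counter of how many notifications each request id received.

    Mirrors the shape the patch gathers via get_tp_group().recv_object();
    the increment step is an assumption, since that line is collapsed above.
    """
    request_notif_counter = defaultdict(int)
    for notifs in all_new_notifs:
        for req_ids in notifs.values():
            for req_id in req_ids:
                request_notif_counter[req_id] += 1
    return request_notif_counter


# Rank 0 reports "r1" and "r2"; rank 1 reports "r1" again.
counts = count_notifications([{"agent0": ["r1", "r2"]}, {"agent1": ["r1"]}])
```

The engine side (the LLMEngine hunk earlier) then compares these counts against a per-request threshold to decide when a prefill is finished.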
deploy/compoundai/sdk/src/compoundai/cli/serve_nova.py
...
@@ -26,7 +26,8 @@ import typing as t
 from typing import Any

 import click
-from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
+from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker

 logger = logging.getLogger("compoundai.serve.nova")
...
@@ -102,7 +103,7 @@ def main(
     server_context.worker_index = worker_id
     class_instance = service.inner()

-    @triton_worker()
+    @dynemo_worker()
     async def worker(runtime: DistributedRuntime):
         if service_name and service_name != service.name:
             server_context.service_type = "service"
...
@@ -157,12 +158,12 @@ def main(
                 # Bind an instance of inner to the endpoint
                 bound_method = endpoint.func.__get__(class_instance)
                 # Only pass request type for now, use Any for response
-                # TODO: Handle a triton_endpoint not having types
+                # TODO: Handle a dynemo_endpoint not having types
                 # TODO: Handle multiple endpoints in a single component
-                triton_wrapped_method = triton_endpoint(endpoint.request_type, Any)(
+                dynemo_wrapped_method = dynemo_endpoint(endpoint.request_type, Any)(
                     bound_method
                 )
-                result = await td_endpoint.serve_endpoint(triton_wrapped_method)
+                result = await td_endpoint.serve_endpoint(dynemo_wrapped_method)
                 # WARNING: unreachable code :( because serve blocks
                 logger.info(f"[{run_id}] Result: {result}")
                 logger.info(f"[{run_id}] Registered endpoint '{name}'")
...
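In the serve_nova changes above, an instance-bound method is produced with `endpoint.func.__get__(class_instance)` and then wrapped by the `dynemo_endpoint(request_type, Any)` decorator factory. A minimal sketch of that wrapping pattern (`endpoint_factory` below is a hypothetical stand-in for `dynemo_endpoint`, whose implementation this commit does not show):

```python
from typing import Any, Callable


def endpoint_factory(request_type: type, response_type: Any) -> Callable:
    """Hypothetical stand-in for dynemo_endpoint(request_type, Any): returns
    a decorator that coerces a raw dict payload into request_type before
    invoking the wrapped handler."""
    def decorate(func: Callable) -> Callable:
        def wrapped(payload):
            if isinstance(payload, dict):
                payload = request_type(**payload)
            return func(payload)
        return wrapped
    return decorate


class EchoRequest:
    def __init__(self, text: str):
        self.text = text


class Service:
    def echo(self, req: EchoRequest) -> str:
        return req.text.upper()


svc = Service()
bound_method = Service.echo.__get__(svc)  # bind the instance, as serve_nova does
wrapped_method = endpoint_factory(EchoRequest, Any)(bound_method)
result = wrapped_method({"text": "hello"})  # result == "HELLO"
```

Binding with `__get__` first means the decorator only ever sees a one-argument callable, so the same factory works for free functions and methods alike.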
deploy/compoundai/sdk/src/compoundai/sdk/decorators.py
...
@@ -50,7 +50,7 @@ class NovaEndpoint:
             if isinstance(args[1], (str, dict)):
                 args[1] = self.request_type.parse_obj(args[1])  # type: ignore
-        # Convert Pydantic model to dict before passing to triton
+        # Convert Pydantic model to dict before passing to dynemo
         if len(args) > 1 and isinstance(args[1], BaseModel):
             args = list(args)  # type: ignore
             args[1] = args[1].model_dump()  # type: ignore
...
deploy/compoundai/sdk/src/compoundai/sdk/dependency.py
...
@@ -72,9 +72,9 @@ class NovaClient:
         else:
             # Create nova worker if no runtime
-            from triton_distributed_rs import DistributedRuntime, triton_worker
+            from dynemo.runtime import DistributedRuntime, dynemo_worker

-            @triton_worker()
+            @dynemo_worker()
             async def stream_worker(runtime: DistributedRuntime):
                 try:
                     # Store runtime for future use
...
examples/python_rs/llm/tensorrt_llm/README.md
...
@@ -90,14 +90,14 @@ Note: NATS and ETCD servers should be running and accessible from the container
 Run the server logging (with debug level logging):
 ```bash
-TRD_LOG=DEBUG http &
+DYN_LOG=DEBUG http &
 ```
 By default the server will run on port 8080.

 Add model to the server:
 ```bash
-llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.chat/completions
+llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.chat/completions
-llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.completions
+llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.completions
 ```

 #### 2. Workers
...
@@ -214,14 +214,14 @@ Run the container interactively with the following command:
 Run the server logging (with debug level logging):
 ```bash
-TRD_LOG=DEBUG http &
+DYN_LOG=DEBUG http &
 ```
 By default the server will run on port 8080.

 Add model to the server:
 ```bash
-llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.chat/completions
+llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.chat/completions
-llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.completions
+llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.completions
 ```

 #### 2. Workers
...
examples/python_rs/llm/tensorrt_llm/common/client.py
...
@@ -19,12 +19,12 @@ import asyncio
 import uvloop
-from triton_distributed.runtime import DistributedRuntime, triton_worker
+from dynemo.runtime import DistributedRuntime, dynemo_worker

 from .protocol import Request


-@triton_worker()
+@dynemo_worker()
 async def worker(
     runtime: DistributedRuntime,
     component: str,
...
@@ -38,7 +38,7 @@ async def worker(
     """
     # create client
     client = (
-        await runtime.namespace("triton-init")
+        await runtime.namespace("dynemo")
         .component(component)
         .endpoint("generate")
         .client()
...
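The client above addresses an endpoint through the fluent chain `runtime.namespace("dynemo").component(component).endpoint("generate").client()`. A toy, synchronous mock of that addressing chain (the real chain is awaited and returns a runtime client; the classes below are illustrative only):

```python
class _Endpoint:
    def __init__(self, path: str):
        self.path = path

    def client(self) -> str:
        # The real call returns a runtime client; a string stands in here.
        return f"client:{self.path}"


class _Component:
    def __init__(self, path: str):
        self.path = path

    def endpoint(self, name: str) -> _Endpoint:
        return _Endpoint(f"{self.path}.{name}")


class Namespace:
    """Toy model of the namespace -> component -> endpoint address chain."""

    def __init__(self, name: str):
        self.name = name

    def component(self, name: str) -> _Component:
        return _Component(f"{self.name}.{name}")


ep = Namespace("dynemo").component("tensorrt-llm").endpoint("generate")
client = ep.client()  # client == "client:dynemo.tensorrt-llm.generate"
```

Each step narrows the address, which is why the commit's rename from the `triton-init` namespace to `dynemo` also forced the matching `llmctl http add ... dynemo.*` changes in the README above.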