OpenDAS / dynamo / Commits / 1af7433b

Commit 1af7433b authored Mar 05, 2025 by Neelay Shah, committed by GitHub on Mar 05, 2025
refactor: rename triton_distributed to dynemo (#22)
Co-authored-by: Graham King <grahamk@nvidia.com>

parent ee4ef06b
Changes: 165 files
Showing 20 changed files with 383 additions and 649 deletions (+383, -649).
.dockerignore                                                  +1   -1
.github/workflows/pre-merge.yml                                +0   -57
ATTRIBUTIONS.md                                                +1   -1
README.md                                                      +1   -1
applications/llm/count/Cargo.lock                              +95  -95
applications/llm/count/Cargo.toml                              +3   -3
applications/llm/count/README.md                               +6   -6
applications/llm/count/src/bin/mock_worker.rs                  +5   -5
applications/llm/count/src/lib.rs                              +4   -6
applications/llm/count/src/main.rs                             +3   -3
container/Dockerfile                                           +17  -17
container/Dockerfile.vllm                                      +18  -18
container/Dockerfile.vllm_nixl                                 +18  -18
container/deps/clone_tensorrtllm.sh                            +5   -5
container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch   +188 -396
deploy/compoundai/sdk/src/compoundai/cli/serve_nova.py         +6   -5
deploy/compoundai/sdk/src/compoundai/sdk/decorators.py         +1   -1
deploy/compoundai/sdk/src/compoundai/sdk/dependency.py         +2   -2
examples/python_rs/llm/tensorrt_llm/README.md                  +6   -6
examples/python_rs/llm/tensorrt_llm/common/client.py           +3   -3
.dockerignore  (view file @ 1af7433b)

@@ -19,7 +19,7 @@
 **/*.plan
 **/.cache/*
 **/*onnx*
-# Engine must be allowed because code contains triton_distributed_engine.py
+# Engine must be allowed because code contains dynemo_engine.py
 **/*tensorrtllm_engines*
 **/*tensorrtllm_models*
 **/*tensorrtllm_checkpoints*
.github/workflows/pre-merge.yml  (view file @ 1af7433b)

@@ -22,25 +22,6 @@ on:
 jobs:
-  # icp_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: ./icp/protos/gen_python.sh
-  #     - run: pytest --verbose icp
-  #   timeout-minutes: 3
   pre-commit:
     runs-on: ubuntu-latest
     permissions:
@@ -52,41 +33,3 @@ jobs:
     timeout-minutes: 3
-  # providers_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #       PROTO_OUT: /python/icp/protos
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: pytest --verbose providers
-  # worker_validation:
-  #   runs-on: ubuntu-latest
-  #   container:
-  #     image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
-  #     env:
-  #       BUILD_NUMBER: ${{ github.job }}
-  #       CUDA_VISIBLE_DEVICES: -1
-  #       PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
-  #       PROTO_OUT: /python/icp/protos
-  #     volumes:
-  #       - ${{ github.workspace }}:/workspace
-  #   permissions:
-  #     contents: read
-  #     packages: read
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - run: ./icp/protos/gen_python.sh
-  #     - run: pytest -p no:warnings --verbose worker/python/tests
-  #   timeout-minutes: 2
ATTRIBUTIONS.md  (view file @ 1af7433b)

@@ -17,7 +17,7 @@ limitations under the License.
 # Open Source License Attribution

-Triton Distributed uses Open Source components. You can find the details of these open-source projects along with license information below.
+Dynemo uses Open Source components. You can find the details of these open-source projects along with license information below.
 We are grateful to the developers for their contributions to open source and acknowledge these below.

 ## nats-py - [Apache License 2.0](https://github.com/nats-io/nats.py/blob/main/LICENSE)
README.md  (view file @ 1af7433b)

@@ -71,7 +71,7 @@ The run script offers a few common workflows:

 1. Running a command in a container and exiting.

    ```
-   ./container/run.sh -- python3 -c "import triton_distributed.runtime; help(triton_distributed.runtime)"
+   ./container/run.sh -- python3 -c "import dynemo.runtime; help(dynemo.runtime)"
    ```

 2. Starting an interactive shell.
applications/llm/count/Cargo.lock  (view file @ 1af7433b)

@@ -737,6 +737,8 @@ version = "0.1.0"
 dependencies = [
  "axum 0.6.20",
  "clap",
+ "dynemo-llm",
+ "dynemo-runtime",
  "opentelemetry",
  "opentelemetry-prometheus",
  "prometheus",
@@ -747,8 +749,6 @@ dependencies = [
  "thiserror 1.0.69",
  "tokio",
  "tracing",
- "triton-distributed-llm",
- "triton-distributed-runtime",
 ]

 [[package]]
@@ -1024,6 +1024,99 @@ dependencies = [
  "syn 2.0.98",
 ]

+[[package]]
+name = "dynemo-llm"
+version = "0.2.1"
+dependencies = [
+ "anyhow",
+ "async-openai",
+ "async-stream",
+ "async-trait",
+ "axum 0.8.1",
+ "bindgen",
+ "blake3",
+ "bs62",
+ "bytes",
+ "chrono",
+ "cmake",
+ "derive_builder",
+ "dynemo-runtime",
+ "either",
+ "erased-serde",
+ "futures",
+ "galil-seiferas",
+ "indexmap 2.7.1",
+ "itertools 0.14.0",
+ "libc",
+ "minijinja",
+ "minijinja-contrib",
+ "prometheus",
+ "pyo3",
+ "regex",
+ "semver",
+ "serde",
+ "serde-pickle",
+ "serde_json",
+ "serde_repr",
+ "strum",
+ "thiserror 2.0.11",
+ "tokenizers",
+ "tokio",
+ "tokio-stream",
+ "tokio-util",
+ "toktrie",
+ "toktrie_hf_tokenizers",
+ "tracing",
+ "unicode-segmentation",
+ "uuid",
+ "validator",
+ "xxhash-rust",
+]
+
+[[package]]
+name = "dynemo-runtime"
+version = "0.2.1"
+dependencies = [
+ "anyhow",
+ "async-nats",
+ "async-once-cell",
+ "async-stream",
+ "async-trait",
+ "async_zmq",
+ "blake3",
+ "bytes",
+ "chrono",
+ "derive-getters",
+ "derive_builder",
+ "educe",
+ "either",
+ "etcd-client",
+ "figment",
+ "futures",
+ "humantime",
+ "local-ip-address",
+ "log",
+ "nid",
+ "nix",
+ "nuid",
+ "once_cell",
+ "prometheus",
+ "rand",
+ "regex",
+ "serde",
+ "serde_json",
+ "socket2",
+ "thiserror 1.0.69",
+ "tokio",
+ "tokio-stream",
+ "tokio-util",
+ "tracing",
+ "tracing-subscriber",
+ "uuid",
+ "validator",
+ "xxhash-rust",
+]
+
 [[package]]
 name = "ed25519"
 version = "2.2.3"
@@ -4232,99 +4325,6 @@ dependencies = [
  "tracing-serde",
 ]

-[[package]]
-name = "triton-distributed-llm"
-version = "0.2.1"
-dependencies = [
- "anyhow",
- "async-openai",
- "async-stream",
- "async-trait",
- "axum 0.8.1",
- "bindgen",
- "blake3",
- "bs62",
- "bytes",
- "chrono",
- "cmake",
- "derive_builder",
- "either",
- "erased-serde",
- "futures",
- "galil-seiferas",
- "indexmap 2.7.1",
- "itertools 0.14.0",
- "libc",
- "minijinja",
- "minijinja-contrib",
- "prometheus",
- "pyo3",
- "regex",
- "semver",
- "serde",
- "serde-pickle",
- "serde_json",
- "serde_repr",
- "strum",
- "thiserror 2.0.11",
- "tokenizers",
- "tokio",
- "tokio-stream",
- "tokio-util",
- "toktrie",
- "toktrie_hf_tokenizers",
- "tracing",
- "triton-distributed-runtime",
- "unicode-segmentation",
- "uuid",
- "validator",
- "xxhash-rust",
-]
-
-[[package]]
-name = "triton-distributed-runtime"
-version = "0.2.1"
-dependencies = [
- "anyhow",
- "async-nats",
- "async-once-cell",
- "async-stream",
- "async-trait",
- "async_zmq",
- "blake3",
- "bytes",
- "chrono",
- "derive-getters",
- "derive_builder",
- "educe",
- "either",
- "etcd-client",
- "figment",
- "futures",
- "humantime",
- "local-ip-address",
- "log",
- "nid",
- "nix",
- "nuid",
- "once_cell",
- "prometheus",
- "rand",
- "regex",
- "serde",
- "serde_json",
- "socket2",
- "thiserror 1.0.69",
- "tokio",
- "tokio-stream",
- "tokio-util",
- "tracing",
- "tracing-subscriber",
- "uuid",
- "validator",
- "xxhash-rust",
-]
-
 [[package]]
 name = "try-lock"
 version = "0.2.5"
applications/llm/count/Cargo.toml  (view file @ 1af7433b)

@@ -21,8 +21,8 @@ license = "Apache-2.0"
 [dependencies]
 # local
-triton-distributed-runtime = { path = "../../../lib/runtime" }
-triton-distributed-llm = { path = "../../../lib/llm" }
+dynemo-runtime = { path = "../../../lib/runtime" }
+dynemo-llm = { path = "../../../lib/llm" }

 # workspace - todo
applications/llm/count/README.md  (view file @ 1af7433b)

@@ -8,17 +8,17 @@ the services associated with that endpoint, do some postprocessing on them,
 and then publish an event with the postprocessed data.

 ```bash
-# For more details, try TRD_LOG=debug
-TRD_LOG=info cargo run --bin count -- --namespace triton-init --component backend --endpoint generate
+# For more details, try DYN_LOG=debug
+DYN_LOG=info cargo run --bin count -- --namespace dynemo --component backend --endpoint generate

-# 2025-02-26T18:45:05.467026Z  INFO count: Creating unique instance of Count at triton-init/components/count/instance
-# 2025-02-26T18:45:05.472146Z  INFO count: Scraping service triton_init_backend_720278f8 and filtering on subject triton_init_backend_720278f8.generate
+# 2025-02-26T18:45:05.467026Z  INFO count: Creating unique instance of Count at dynemo/components/count/instance
+# 2025-02-26T18:45:05.472146Z  INFO count: Scraping service dynemo_init_backend_720278f8 and filtering on subject dynemo_init_backend_720278f8.generate
 # ...
 ```

 With no matching endpoints running, you should see warnings in the logs:

 ```bash
-2025-02-26T18:45:06.474161Z  WARN count: No endpoints found matching subject triton_init_backend_720278f8.generate
+2025-02-26T18:45:06.474161Z  WARN count: No endpoints found matching subject dynemo_init_backend_720278f8.generate
 ```

 To see metrics published to a matching endpoint, you can use the
@@ -35,7 +35,7 @@ since the endpoint will automatically get discovered.
 When stats are found from the target endpoints being listened on, count will
 aggregate and publish some metrics as both an event and to a prometheus web server:

 ```
-2025-02-28T04:05:58.077901Z  INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
+2025-02-28T04:05:58.077901Z  INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
 ```
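The `load_avg`/`load_std` fields in the log line above are derived from each worker's KV-block utilization. A minimal sketch of that arithmetic, reproducing the logged numbers (the derivation is inferred from the values shown, not quoted from count's source):

```python
# Two workers report kv_active_blocks 77 and 29 out of kv_total_blocks 100.
loads = [77 / 100 * 100, 29 / 100 * 100]  # per-worker KV utilization, in percent
load_avg = sum(loads) / len(loads)        # 53.0, as logged
load_std = (sum((x - load_avg) ** 2 for x in loads) / len(loads)) ** 0.5  # 24.0
print(load_avg, load_std)
```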
 To see the metrics being published in prometheus format, you can run:
applications/llm/count/src/bin/mock_worker.rs  (view file @ 1af7433b)

@@ -13,10 +13,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use rand::Rng;
-use std::sync::Arc;
-use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
-use triton_distributed_runtime::{
+use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
+use dynemo_runtime::{
     logging,
     pipeline::{
         async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
@@ -25,6 +23,8 @@ use triton_distributed_runtime::{
     protocols::annotated::Annotated,
     stream, DistributedRuntime, Result, Runtime, Worker,
 };
+use rand::Rng;
+use std::sync::Arc;

 fn main() -> Result<()> {
     logging::init();
@@ -69,7 +69,7 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
     // we must first create a service, then we can attach one more more endpoints
     runtime
-        .namespace("triton-init")?
+        .namespace("dynemo")?
         .component("backend")?
         .service_builder()
         .create()
applications/llm/count/src/lib.rs  (view file @ 1af7433b)

@@ -20,13 +20,11 @@ use prometheus::register_gauge_vec;
 use serde::{Deserialize, Serialize};
 use std::net::SocketAddr;
-use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
-use triton_distributed_llm::kv_router::scheduler::Endpoint;
-use triton_distributed_llm::kv_router::scoring::ProcessedEndpoints;
-use triton_distributed_runtime::{
-    distributed::Component, service::EndpointInfo, utils::Duration, Result,
-};
+use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
+use dynemo_llm::kv_router::scheduler::Endpoint;
+use dynemo_llm::kv_router::scoring::ProcessedEndpoints;
+use dynemo_runtime::{distributed::Component, service::EndpointInfo, utils::Duration, Result};

 /// Configuration for LLM worker load capacity metrics
 #[derive(Debug, Clone, Serialize, Deserialize)]
applications/llm/count/src/main.rs  (view file @ 1af7433b)

@@ -24,7 +24,7 @@
 //! - KV Cache Blocks: [Active, Total]

 use clap::Parser;
-use triton_distributed_runtime::{
+use dynemo_runtime::{
     error, logging,
     traits::events::EventPublisher,
     utils::{Duration, Instant},
@@ -50,7 +50,7 @@ struct Args {
     endpoint: String,

     /// Namespace to operate in
-    #[arg(long, env = "TRD_NAMESPACE", default_value = "triton-init")]
+    #[arg(long, env = "DYN_NAMESPACE", default_value = "dynemo")]
     namespace: String,

     /// Polling interval in seconds (minimum 1 second)
@@ -155,7 +155,7 @@ mod tests {
     #[test]
     fn test_namespace_from_env() {
-        env::set_var("TRD_NAMESPACE", "test-namespace");
+        env::set_var("DYN_NAMESPACE", "test-namespace");
         let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);
         assert_eq!(args.namespace, "test-namespace");
     }
container/Dockerfile  (view file @ 1af7433b)

@@ -16,7 +16,7 @@
 ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
 ARG BASE_IMAGE_TAG="25.01-py3"

-FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
+FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynemo

 # TODO: non root user by default
@@ -34,7 +34,7 @@ RUN rustup toolchain install 1.85.0-x86_64-unknown-linux-gnu
 # Install OpenAI-compatible frontend and its dependencies from triton server
 # repository. These are used to have a consistent interface, schema, and FastAPI
-# app between Triton Core and Triton Distributed implementations.
+# app between Triton Core and Dynemo implementations.
 ARG OPENAI_SERVER_TAG="r25.01"
 RUN mkdir -p /opt/tritonserver/python && \
     cd /opt/tritonserver/python && \
@@ -78,7 +78,7 @@ ARG TENSORRTLLM_SKIP_CLONE=
 ENV FRAMEWORK=${FRAMEWORK}
 RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,target=/tmp/requirements.txt \
     --mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
-    if [[ "$FRAMEWORK" == "TENSORRTLLM" ]]; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ]; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/triton/llm_binding; fi; fi
+    if [[ "$FRAMEWORK" == "TENSORRTLLM" ]]; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ]; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/dynemo/llm_binding; fi; fi

 RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
@@ -106,7 +106,7 @@ ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
 ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
 ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
 ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
-ENV VLLM_KV_CAPI_PATH="/opt/triton/llm_binding/lib/libtriton_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/llm_binding/lib/libdynemo_llm_capi.so"
 ENV PYTHONUNBUFFERED=1

 # Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
@@ -159,27 +159,27 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c/ && \
     cargo build --release --locked && cargo doc --no-deps

-# Install uv, create virtualenv for general use, and build triton_distributed wheel
+# Install uv, create virtualenv for general use, and build dynemo wheel
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12 && \
-    source /opt/triton/venv/bin/activate && \
+RUN mkdir /opt/dynemo && \
+    uv venv /opt/dynemo/venv --python 3.12 && \
+    source /opt/dynemo/venv/bin/activate && \
     uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+    uv pip install /workspace/dist/dynemo*cp312*.whl

 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.

-# Install triton_distributed_runtime and triton_distributed_llm wheels globally in container for tests that
+# Install dynemo.runtime and dynemo.llm wheels globally in container for tests that
 # currently run without virtual environment activated.
 # TODO: In future, we may use a virtualenv for everything and remove this.
-RUN pip install /opt/triton/bindings/wheels/triton_distributed*cp312*.whl
+RUN pip install /opt/dynemo/bindings/wheels/dynemo*cp312*.whl

-# Copy everything in after install steps to avoid re-running build/install
+# Copy everything in after g install steps to avoid re-running build/install
 # commands on unrelated changes in other dirs.
 COPY . /workspace
container/Dockerfile.vllm  (view file @ 1af7433b)

@@ -24,17 +24,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
 # Install uv and create virtualenv
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12
+RUN mkdir /opt/dynemo && \
+    uv venv /opt/dynemo/venv --python 3.12

 # Activate virtual environment
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

 # Install patched vllm - keep this early in Dockerfile to avoid
 # rebuilds from unrelated source code changes
 ARG VLLM_REF="v0.7.2"
-ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
+ARG VLLM_PATCH="vllm_${VLLM_REF}-dynemo-kv-disagg-patch.patch"
 RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
     bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
@@ -100,25 +100,25 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c && \
     cargo build --release --locked && cargo doc --no-deps

-# Build triton_distributed wheel
-RUN source /opt/triton/venv/bin/activate && \
+# Build dynemo wheel
+RUN source /opt/dynemo/venv/bin/activate && \
     uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+    uv pip install /workspace/dist/dynemo*cp312*.whl

 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.

-# Tell vllm to use the Triton LLM C API for KV Cache Routing
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"

 # FIXME: Copy more specific folders in for dev/debug after directory restructure
 COPY . /workspace

-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
@@ -136,10 +136,10 @@ RUN apt update -y && \
     echo "set -g mouse on" >> /root/.tmux.conf

 # Set environment variables
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"

 # Copy binaries
 COPY --from=dev /usr/local/bin/http /usr/local/bin/http
@@ -166,7 +166,7 @@ COPY examples/python_rs/llm/vllm /workspace/examples/python_rs/llm/vllm
 WORKDIR /workspace

-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
container/Dockerfile.vllm_nixl  (view file @ 1af7433b)

@@ -150,17 +150,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
 # Install uv and create virtualenv
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12
+RUN mkdir /opt/dynemo && \
+    uv venv /opt/dynemo/venv --python 3.12

 # Activate virtual environment
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

 # Install patched vllm - keep this early in Dockerfile to avoid
 # rebuilds from unrelated source code changes
 ARG VLLM_REF="v0.7.2"
-ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
+ARG VLLM_PATCH="vllm_${VLLM_REF}-dynemo-kv-disagg-patch.patch"
 RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
     bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
@@ -225,25 +225,25 @@ COPY lib/bindings /workspace/lib/bindings
 RUN cd lib/bindings/c && \
     cargo build --release --locked && cargo doc --no-deps

-# Build triton_distributed wheel
-RUN source /opt/triton/venv/bin/activate && \
+# Build dynemo wheel
+RUN source /opt/dynemo/venv/bin/activate && \
     uv build --wheel --out-dir /workspace/dist && \
-    uv pip install /workspace/dist/triton_distributed*cp312*.whl
+    uv pip install /workspace/dist/dynemo*cp312*.whl

 # Package the bindings
-RUN mkdir -p /opt/triton/bindings/wheels && \
-    mkdir /opt/triton/bindings/lib && \
-    cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
-    cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
-    cp -r lib/bindings/c/include /opt/triton/bindings/.
+RUN mkdir -p /opt/dynemo/bindings/wheels && \
+    mkdir /opt/dynemo/bindings/lib && \
+    cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
+    cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
+    cp -r lib/bindings/c/include /opt/dynemo/bindings/.

-# Tell vllm to use the Triton LLM C API for KV Cache Routing
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"

 # FIXME: Copy more specific folders in for dev/debug after directory restructure
 COPY . /workspace

-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
@@ -261,10 +261,10 @@ RUN apt update -y && \
     echo "set -g mouse on" >> /root/.tmux.conf

 # Set environment variables
-ENV VIRTUAL_ENV=/opt/triton/venv
+ENV VIRTUAL_ENV=/opt/dynemo/venv
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
-ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
+ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"

 # Copy binaries
 COPY --from=dev /usr/local/bin/http /usr/local/bin/http
@@ -291,7 +291,7 @@ COPY examples/python_rs/llm/vllm_nixl /workspace/examples/python_rs/llm/vllm_nix
 WORKDIR /workspace

-# FIXME: May want a modification with triton-distributed banner on entry
+# FIXME: May want a modification with dynemo-distributed banner on entry
 ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
 CMD []
container/deps/clone_tensorrtllm.sh  (view file @ 1af7433b)

@@ -16,7 +16,7 @@
 TENSORRTLLM_BACKEND_REPO_TAG=
 TENSORRTLLM_BACKEND_REBUILD=
-TRITON_LLM_PATH=
+DYNEMO_LLM_PATH=
 GIT_TOKEN=
 GIT_REPO=
@@ -43,9 +43,9 @@ get_options() {
             missing_requirement $1
         fi
         ;;
-    --triton-llm-path)
+    --dynemo-llm-path)
         if [ "$2" ]; then
-            TRITON_LLM_PATH=$2
+            DYNEMO_LLM_PATH=$2
             shift
         else
             missing_requirement $1
@@ -147,9 +147,9 @@ if [ ! -z ${TENSORRTLLM_BACKEND_REBUILD} ]; then
     # Build the backend
     (cd inflight_batcher_llm/src \
-        && cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DTRITON_LLM_PATH=$TRITON_LLM_PATH .. \
+        && cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DDYNEMO_LLM_PATH=$DYNEMO_LLM_PATH .. \
         && make install \
-        && cp libtriton_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
+        && cp libdynemo_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
        && cp trtllmExecutorWorker /opt/tritonserver/backends/tensorrtllm/ \
     )
 fi
container/deps/vllm/vllm_v0.7.2-triton-kv-disagg-patch.patch → container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch  (view file @ 1af7433b)

This file is itself a patch, so each changed line carries two markers: the outer one for this commit, the inner one for the vllm patch.

@@ -31,7 +31,7 @@ index 9ba49757..a2f88854 100644
                       f"and `kv_both`")
 -        if self.kv_connector is not None and self.kv_role is None:
-+        if self.kv_connector is not None and self.kv_connector != "TritonNixlConnector" and self.kv_role is None:
++        if self.kv_connector is not None and self.kv_connector != "DynemoNixlConnector" and self.kv_role is None:
              raise ValueError("Please specify kv_disagg_role when kv_connector "
                               "is set, supported roles are `kv_producer`, "
                               "`kv_consumer`, and `kv_both`")
@@ -44,7 +44,7 @@ index 9ba49757..a2f88854 100644
      def need_kv_parallel_group(self) -> bool:
          # for those database-based connector, vLLM does not need to create
          # parallel group, and in that case the kv parallel size will be 1.
-+        if self.kv_connector == "TritonNixlConnector":
++        if self.kv_connector == "DynemoNixlConnector":
 +            return False
          return self.kv_connector is not None and self.kv_parallel_size > 1
@@ -277,7 +277,7 @@ index 00000000..350453cd
 +logger = logging.getLogger(__name__)
 +
 +
-+class TritonResult:
++class DynemoResult:
 +    OK = 0
 +    ERR = 1
 +
@@ -290,12 +290,12 @@ index 00000000..350453cd
 +
 +        try:
 +            self.lib = ctypes.CDLL(lib_path)
-+            self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
-+            self.lib.triton_llm_init.restype = c_uint32
++            self.lib.dynemo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
++            self.lib.dynemo_llm_init.restype = c_uint32
 +
-+            result = self.lib.triton_llm_init(namespace.encode(),
++            result = self.lib.dynemo_llm_init(namespace.encode(),
 +                                              component.encode(), worker_id)
-+            if result == TritonResult.OK:
++            if result == DynemoResult.OK:
 +                logger.info(
 +                    "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
 +                )
@@ -306,7 +306,7 @@ index 00000000..350453cd
 +            print(f"Failed to load {lib_path}")
 +            raise e
 +
-+        self.lib.triton_kv_event_publish_stored.argtypes = [
++        self.lib.dynemo_kv_event_publish_stored.argtypes = [
 +            ctypes.c_uint64,                  # event_id
 +            ctypes.POINTER(ctypes.c_uint32),  # token_ids
 +            ctypes.POINTER(ctypes.c_size_t),  # num_block_tokens
@@ -315,14 +315,14 @@ index 00000000..350453cd
 +            ctypes.POINTER(ctypes.c_uint64),  # parent_hash
 +            ctypes.c_uint64,                  # lora_id
 +        ]
-+        self.lib.triton_kv_event_publish_stored.restype = ctypes.c_uint32  # triton_llm_result_t
++        self.lib.dynemo_kv_event_publish_stored.restype = ctypes.c_uint32  # dynemo_llm_result_t
 +
-+        self.lib.triton_kv_event_publish_removed.argtypes = [
++        self.lib.dynemo_kv_event_publish_removed.argtypes = [
 +            ctypes.c_uint64,                  # event_id
 +            ctypes.POINTER(ctypes.c_uint64),  # block_ids
 +            ctypes.c_size_t,                  # num_blocks
 +        ]
-+        self.lib.triton_kv_event_publish_removed.restype = ctypes.c_uint32  # triton_llm_result_t
++        self.lib.dynemo_kv_event_publish_removed.restype = ctypes.c_uint32  # dynemo_llm_result_t
 +
 +        self.event_id_counter = 0
 +
@@ -336,7 +336,7 @@ index 00000000..350453cd
 +                       if parent is not None else None)
 +
 +        # Publish the event
-+        result = self.lib.triton_kv_event_publish_stored(
++        result = self.lib.dynemo_kv_event_publish_stored(
 +            self.event_id_counter,  # uint64_t event_id
 +            token_ids_arr,          # const uint32_t *token_ids
 +            num_block_tokens,       # const uintptr_t *num_block_tokens
@@ -346,7 +346,7 @@ index 00000000..350453cd
 +            0,                      # uint64_t lora_id
 +        )
 +
-+        if result == TritonResult.OK:
++        if result == DynemoResult.OK:
 +            logger.debug(f"Store - Published KV Event: {block.content_hash}")
 +        else:
 +            logger.debug(
@@ -355,28 +355,23 @@ index 00000000..350453cd
 +        self.event_id_counter += 1
 +
 +    def enqueue_removed_event(self, block_hash: PrefixHash):
-+        result = self.lib.triton_kv_event_publish_removed(
++        result = self.lib.dynemo_kv_event_publish_removed(
 +            self.event_id_counter,
 +            (ctypes.c_uint64 * 1)(block_hash),
 +            1,
 +        )
 +
-+        if result == TritonResult.OK:
++        if result == DynemoResult.OK:
 +            logger.debug(f"Remove - Published KV Event: {block_hash}")
 +        else:
 +            logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
 +
 +        self.event_id_counter += 1
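An aside for readers skimming the hunk above: the event manager drives the renamed C API through Python's ctypes, declaring `argtypes`/`restype` before each call. A self-contained sketch of the same declare-then-call pattern, using libc's `abs` so it runs anywhere on Linux (the `dynemo_*` symbols above follow the identical shape; this is illustrative, not the patch's code):

```python
import ctypes

# Declare the C signature up front, as the patch does for dynemo_llm_init and
# the dynemo_kv_event_publish_* functions, then call through the handle.
libc = ctypes.CDLL("libc.so.6")  # assumes a Linux host
libc.abs.argtypes = [ctypes.c_int]
libc.abs.restype = ctypes.c_int
assert libc.abs(-5) == 5  # integer status checks mirror DynemoResult.OK / ERR
```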
 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
-index f507847a..abe574d1 100644
+index f507847a..ee20d50c 100644
 --- a/vllm/core/scheduler.py
 +++ b/vllm/core/scheduler.py
-@@ -4,22 +4,22 @@
- import enum
- import os
- import random
- import time
-+import copy
- from collections import deque
+@@ -8,18 +8,17 @@
  from collections import deque
  from dataclasses import dataclass, field
  from typing import Callable, Deque, Dict, Iterable, List, Optional
  from typing import Sequence as GenericSequence
@@ -398,7 +393,7 @@ index f507847a..abe574d1 100644
  logger = init_logger(__name__)

  # Test-only. If configured, decode is preempted with
-@@ -325,12 +325,14 @@ class Scheduler:
+@@ -325,12 +324,14 @@ class Scheduler:
      def __init__(
          self,
@@ -413,7 +408,7 @@ index f507847a..abe574d1 100644
          self.scheduler_config = scheduler_config
          self.cache_config = cache_config
          # Note for LoRA scheduling: the current policy is extremely
-@@ -356,6 +358,7 @@ class Scheduler:
+@@ -356,6 +357,7 @@ class Scheduler:
          # Create the block space manager.
          self.block_manager = BlockSpaceManagerImpl(
@@ -421,7 +416,7 @@ index f507847a..abe574d1 100644
              block_size=self.cache_config.block_size,
              num_gpu_blocks=num_gpu_blocks,
              num_cpu_blocks=num_cpu_blocks,
-@@ -371,6 +374,16 @@ class Scheduler:
+@@ -371,6 +373,14 @@ class Scheduler:
          # Sequence groups in the SWAPPED state.
          # Contain decode requests that are swapped out.
          self.swapped: Deque[SequenceGroup] = deque()
@@ -429,8 +424,6 @@ index f507847a..abe574d1 100644
 +        # Sequence groups in the REMOTE_PREFILLING state.
 +        # Contain requests that are being prefilled by a remote worker.
 +        self.remote_prefilling: Deque[SequenceGroup] = deque()
-+        # Contain requests that are being prefilled by a local worker.
-+        self.prefill_sending: Deque[SequenceGroup] = deque()
 +
 +        self._remote_prefill_outputs: Dict[str, int] = {}
 +
@@ -438,25 +431,24 @@ index f507847a..abe574d1 100644
      # Sequence groups finished requests ids since last step iteration.
      # It lets the model know that any state associated with these requests
      # can and must be released after the current step.
-@@ -501,7 +514,7 @@ class Scheduler:
+@@ -501,7 +511,7 @@ class Scheduler:
      def has_unfinished_seqs(self) -> bool:
          return len(self.waiting) != 0 or len(self.running) != 0 or len(
 -            self.swapped) != 0
-+            self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
++            self.swapped) != 0 or len(self.remote_prefilling) != 0

      def get_prefix_cache_hit_rate(self, device: Device) -> float:
          return self.block_manager.get_prefix_cache_hit_rate(device)
-@@ -523,6 +536,8 @@ class Scheduler:
+@@ -523,6 +533,7 @@ class Scheduler:
          budget: SchedulingBudget,
          curr_loras: Optional[Set[int]],
          enable_chunking: bool = False,
-+        finished_prefills: Optional[Set[str]] = None,
-+        finished_transfers: Optional[Set[str]] = None
++        finished_prefills: Optional[Set[str]] = None
      ) -> SchedulerRunningOutputs:
          """Schedule sequence groups that are running.
-@@ -537,6 +552,8 @@ class Scheduler:
+@@ -537,6 +548,8 @@ class Scheduler:
              chunked number of tokens are scheduled if
              `budget.num_batched_tokens` has not enough capacity to schedule
              all tokens.
@@ -465,7 +457,7 @@ index f507847a..abe574d1 100644
          Returns:
              SchedulerRunningOutputs.
-@@ -566,6 +583,38 @@ class Scheduler:
+@@ -566,6 +579,24 @@ class Scheduler:
          preempted: List[SequenceGroup] = ret.preempted
          swapped_out: List[SequenceGroup] = ret.swapped_out
@@ -476,7 +468,6 @@ index f507847a..abe574d1 100644
 +                if seq_group.request_id not in finished_prefills:
 +                    leftover_remote_prefilling_sequences.append(seq_group)
 +                    continue
-+
 +                else:
 +                    finished_prefills.remove(seq_group.request_id)
 +                assert len(seq_group.seqs) == 1
@@ -487,63 +478,39 @@ index f507847a..abe574d1 100644
 +                seq.data._stage = SequenceStage.DECODE
 +                self.running.appendleft(seq_group)
 +            remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
-+
-+        remote_transfers_queue = self.prefill_sending
-+        leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
-+        while remote_transfers_queue:
-+            seq_group = remote_transfers_queue.popleft()
-+            if seq_group.request_id not in finished_transfers:
-+                leftover_remote_transfers_sequences.append(seq_group)
-+            else:
-+                finished_transfers.remove(seq_group.request_id)
-+                assert len(seq_group.seqs) == 1
-+                seq = seq_group.seqs[0]
-+                self.free_seq(seq)
-+        remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
 +
          running_queue = self.running
          assert len(self._async_stopped) == 0
          while running_queue:
-@@ -1008,7 +1057,17 @@ class Scheduler:
+@@ -1008,7 +1039,7 @@ class Scheduler:
              if curr_loras is not None and lora_int_id > 0:
                  curr_loras.add(lora_int_id)
              waiting_queue.popleft()
 -            self._allocate_and_set_running(seq_group)
-+
-+            seq_group_copy = copy.deepcopy(seq_group)
-+            seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
-+
-+            logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
-+            logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
 +            self._allocate_and_set_running_or_remote_prefill(seq_group)
-+            if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
-+                logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
-+                self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
-+                self.prefill_sending.append(seq_group_copy)
              if enable_chunking and self.scheduler_config.is_multi_step:
                  blocks_to_copy: List[Tuple[int, int]] = []
-@@ -1048,7 +1107,7 @@ class Scheduler:
+@@ -1048,7 +1079,7 @@ class Scheduler:
              num_lookahead_slots=self._get_num_lookahead_slots(
                  is_prefill=True, enable_chunking=enable_chunking))

 -    def _schedule_default(self) -> SchedulerOutputs:
-+    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
          """Schedule queued requests.

          The current policy is designed to optimize the throughput. First,
-@@ -1090,7 +1149,9 @@ class Scheduler:
+@@ -1090,7 +1121,8 @@ class Scheduler:
          if len(prefills.seq_groups) == 0:
              running_scheduled = self._schedule_running(budget,
                                                         curr_loras,
 -                                                       enable_chunking=False)
-+                                                       enable_chunking=False,
-+                                                       finished_prefills=finished_prefills,
-+                                                       finished_transfers=finished_transfers)
++                                                       enable_chunking=False,
++                                                       finished_prefills=finished_prefills)

          # If any sequence group is preempted, do not swap in any sequence
          # group. because it means there's no slot for new running requests.
-@@ -1106,7 +1167,12 @@ class Scheduler:
+@@ -1106,7 +1138,12 @@ class Scheduler:
          self.waiting.extendleft(running_scheduled.preempted)

          # Update new running requests.
          if len(prefills.seq_groups) > 0:
@@ -557,31 +524,30 @@ index f507847a..abe574d1 100644
          self.running.extend(running_scheduled.decode_seq_groups_list)

-@@ -1248,12 +1314,14 @@ class Scheduler:
+@@ -1248,12 +1285,14 @@ class Scheduler:
              len(running_scheduled.swapped_out)),
          )

 -    def _schedule(self) -> SchedulerOutputs:
-+    def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
          """Schedule queued requests."""
          if self.scheduler_config.chunked_prefill_enabled:
-+            if finished_prefills or finished_transfers:
++            if finished_prefills:
 +                raise ValueError("Chunked prefill does not support remote prefills")
              return self._schedule_chunked_prefill()
          else:
 -            return self._schedule_default()
-+            return self._schedule_default(finished_prefills, finished_transfers)
++            return self._schedule_default(finished_prefills)

      def _can_append_slots(self, seq_group: SequenceGroup,
                            enable_chunking: bool) -> bool:
-@@ -1287,14 +1355,16 @@ class Scheduler:
+@@ -1287,14 +1326,15 @@ class Scheduler:
          return no_single_seq

      def schedule(
 -            self
 +            self,
-+            finished_prefills: Optional[Set[str]] = None,
-+            finished_transfers: Optional[Set[str]] = None
++            finished_prefills: Optional[Set[str]] = None
      ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
          # Schedule sequence groups.
          # This function call changes the internal states of the scheduler
@@ -590,11 +556,11 @@ index f507847a..abe574d1 100644
 -        scheduler_outputs: SchedulerOutputs = self._schedule()
 +        scheduler_start_time = time.perf_counter()
-+        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
++        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
          now = time.time()

          if not self.cache_config.enable_prefix_caching:
-@@ -1333,7 +1403,8 @@ class Scheduler:
+@@ -1333,7 +1373,8 @@ class Scheduler:
              encoder_seq_data = None
              cross_block_table = None
@@ -604,24 +570,18 @@ index f507847a..abe574d1 100644
              seq_id = seq.seq_id
              seq_data[seq_id] = seq.data
              block_tables[seq_id] = self.block_manager.get_block_table(seq)
-@@ -1364,9 +1435,16 @@ class Scheduler:
+@@ -1364,6 +1405,10 @@ class Scheduler:
                      < seqs[0].data.get_len()):
                  do_sample = False

 +            is_remote_prefill = False
 +            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
 +                is_remote_prefill = True
-+            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
-+                block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
-+
              # It assumes the scheduled_seq_groups is ordered by
              # prefill < decoding.
              if is_first_prefill or not self.scheduler_config.send_delta_data:
-+                logger.debug("Assinged blocks: %s", block_tables)
                  seq_group_metadata = SequenceGroupMetadata(
                      request_id=seq_group.request_id,
                      is_prompt=is_prompt,
-@@ -1392,6 +1470,7 @@ class Scheduler:
+@@ -1392,6 +1437,7 @@ class Scheduler:
                  if scheduler_outputs.num_prefill_groups > 0 else None,
                  mm_processor_kwargs=seq_group.mm_processor_kwargs,
                  prompt_adapter_request=seq_group.prompt_adapter_request,
@@ -629,7 +589,7 @@ index f507847a..abe574d1 100644
              )
          else:
              # When SPMD mode is enabled, we only send delta data except for
-@@ -1490,10 +1569,13 @@ class Scheduler:
+@@ -1490,10 +1536,13 @@ class Scheduler:
          self._async_stopped.clear()
@@ -645,80 +605,12 @@ index f507847a..abe574d1 100644
      def _append_slots(self,
                        seq_group: SequenceGroup,
-diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
-new file mode 100644
-index 00000000..9b938039
---- /dev/null
-+++ b/vllm/distributed/device_communicators/kv_rearrange.py
-@@ -0,0 +1,61 @@
-+import torch
-+import triton
-+import triton.language as tl
-+
-+@triton.jit
-+def rearrange_kernel(
-+    t1_ptr,
-+    t2_ptr,
-+    N,
-+    B,
-+    H,
-+    C,
-+    d,
-+    tensor_subset_size,
-+    block_size,
-+    token_size,
-+    BLOCK_SIZE: tl.constexpr,
-+):
-+    pid = tl.program_id(0)
-+
-+    block_start = pid * BLOCK_SIZE
-+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-+
-+    curr_n = offsets // block_size
-+    curr_b = offsets // token_size % B
-+    curr_h = offsets // C % H
-+    curr_c = offsets % C
-+
-+    src_pos = offsets
-+
-+    tp_group = curr_h * d // H
-+    dst_h = curr_h % (H // d)
-+    tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
-+
-+    dst_pos = tensor_subset_size * tp_group + tp_group_offset
-+
-+    tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
-+
-+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
-+    N, B, H, C = t1.shape
-+
-+    assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
-+    assert H % d == 0, "H must be divisible by d"
-+
-+    block_size = B * H * C
-+    token_size = H * C
-+    tensor_size = N * block_size
-+    tensor_subset_size = tensor_size // d
-+
-+    BLOCK_SIZE = 1024
-+    grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
-+
-+    rearrange_kernel[grid](
-+        t1, t2,
-+        N, B, H, C,
-+        d,
-+        tensor_subset_size,
-+        block_size,
-+        token_size,
-+        BLOCK_SIZE=BLOCK_SIZE
-+    )
-\ No newline at end of file
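For orientation: the Triton kernel in the removed file reshuffles an (N, B, H, C) KV-cache tensor so that each of d tensor-parallel groups owns one contiguous slice of the head dimension. A hedged pure-PyTorch restatement of that index math (derived from the kernel's offset arithmetic above; verify against the kernel before relying on it):

```python
import torch

def rearrange_reference(t1: torch.Tensor, d: int) -> torch.Tensor:
    # View H as d TP groups of H // d heads, then lay the groups out
    # group-major so each rank's slice is contiguous in the flat buffer.
    N, B, H, C = t1.shape
    assert H % d == 0, "H must be divisible by d"
    return t1.view(N, B, d, H // d, C).permute(2, 0, 1, 3, 4).reshape(N, B, H, C)
```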
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..
f1618bc4
index 00000000..
bc962726
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,
318
@@
@@ -0,0 +1,
249
@@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
...
...
@@ -726,18 +618,39 @@ index 00000000..f1618bc4
+import msgspec
+import time
+import uuid
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+from nixl_wrapper import nixl_wrapper as NixlWrapper
+
+logger = init_logger(__name__)
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ logger.info("NIXL is available")
+except ImportError:
+ logger.warning("NIXL is not available")
+ NixlWrapper = None
+
+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
+ logger.info("Initializing patched NixlWrapper")
+ import nixl_bindings as nixl
+ # Read available backends and device info from nixl_config
+ # For now setting the multithreading to enabled.
+ devices = nixl.nixlAgentConfig(False)
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ print("Initializied NIXL agent:", agent_name)
+
+NixlWrapper.__init__ = nixl_wrapper_init_patch
+
+
+
+class NixlMetadata(
+ msgspec.Struct,
...
...
@@ -749,20 +662,14 @@ index 00000000..f1618bc4
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
+
+
+class
Triton
NixlConnector:
+class
Dynemo
NixlConnector:
+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
+ self.vllm_config = vllm_config
+ if NixlWrapper is None:
+ logger.error("NIXL is not available")
+ raise RuntimeError("NIXL is not available")
+ logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
+ self.block_len = None
+ self.kv_caches = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
...
...
@@ -771,51 +678,33 @@ index 00000000..f1618bc4
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ caches_data = []
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ region_len = cache.numel() * cache.element_size()
+ gpu_id = cache.get_device()
+ assert gpu_id > -1, "Tensor is not on GPU"
+ caches_data.append((base_addr, region_len, gpu_id))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
...
...
@@ -825,14 +714,10 @@ index 00000000..f1618bc4
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+ def add_remote_agent(self, engine_id, agent_metadata):
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+ self._remote_agents[engine_id] = agent_name
+ return agent_name
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
...
...
@@ -847,29 +732,17 @@ index 00000000..f1618bc4
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+ if rank is None:
+ rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ def _get_range_descs(self, engine_id, ranges, layer_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+ start_offset = range_start * self.block_len
+ blocks_len = (range_end - range_start + 1) * self.block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
...
...
@@ -882,9 +755,9 @@ index 00000000..f1618bc4
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([
sorted_block_ids[i],
sorted_block_ids[i]])
+ ranges.append([sorted_block_ids[i]])
+ else:
+ ranges[-1]
[1] =
sorted_block_ids[i]
+ ranges[-1]
.append(
sorted_block_ids[i]
)
+ return ranges
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
...
...
@@ -927,21 +800,8 @@ index 00000000..f1618bc4
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ return descs_ids
+ def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
...
...
@@ -950,61 +810,43 @@ index 00000000..f1618bc4
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+
+
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+
+ for i in range(tp_multiplier):
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
...
...
@@ -1018,26 +860,6 @@ index 00000000..f1618bc4
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+ def get_done_tranfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
+ running_reqs = []
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
+ else:
+ raise RuntimeError("Transfer failed with state %s", xfer_state)
+ if len(running_reqs) == 0:
+ done_req_ids.append(req_id)
+ else:
+ self._transfers[req_id] = running_reqs
+ return done_req_ids
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
...
...
@@ -1064,9 +886,9 @@ index fe480533..61a357d0 100644
"SimpleConnector")
+
+KVConnectorFactory.register_connector(
+ "
Triton
NcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.
triton
_connector",
+ "
Triton
Connector")
+ "
Dynemo
NcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.
dynemo
_connector",
+ "
Dynemo
Connector")
\ No newline at end of file
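This factory registration is what makes the renamed connector reachable by name. Under vLLM 0.7.x conventions it would be selected through the standard KV-transfer config, roughly as in the sketch below (only the connector name comes from this patch; the `KVTransferConfig` fields and values are the stock vLLM ones, shown with illustrative values):

```python
# Sketch: selecting the registered connector by name, vLLM 0.7.x style.
# Only "DynemoNcclConnector" comes from this patch; the rest is stock vLLM
# disaggregated-prefill configuration with illustrative values.
from vllm import LLM
from vllm.config import KVTransferConfig

ktc = KVTransferConfig.from_cli(
    '{"kv_connector": "DynemoNcclConnector", "kv_role": "kv_producer", '
    '"kv_rank": 0, "kv_parallel_size": 2}'
)
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", kv_transfer_config=ktc)
```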
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
...
...
@@ -1396,7 +1218,7 @@ index 2033e976..e33919c1 100644
+ world_group.broadcast_object(kv_config_enhanced)
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in
Triton
Distributed vllm patch")
+ raise NotImplementedError("MooncakeConnector is not supported in
Dynemo
Distributed vllm patch")
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
...
...
@@ -1407,11 +1229,11 @@ index 2033e976..e33919c1 100644
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
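The hunk above also shows the shape of the config handshake: one rank fills in the enhanced KV settings and every other rank receives them through the same `broadcast_object` call. Distilled into a sketch (the two keys are the ones visible in this diff; the sender test and values are illustrative stand-ins):

```python
# Sketch of the exchange in the hunk above: one rank sends, the rest receive.
# The two keys are taken from this diff; the is_sender flag and values are
# illustrative stand-ins for the patch's actual condition.
def sync_kv_config(world_group, is_sender: bool) -> dict:
    if is_sender:
        cfg = {
            "kv_consumers_pipeline_parallel_size": 1,  # illustrative
            "kv_producers_parallel_size": 1,           # illustrative
        }
        world_group.broadcast_object(cfg)   # driver publishes the config
        return cfg
    return world_group.broadcast_object()   # other ranks receive it
```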
diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
...
...
@@ -1443,7 +1265,7 @@ index 00000000..cb3b3660
+logger = init_logger(__name__)
+
+
+class TritonConnector(KVConnectorBase):
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
...
...
@@ -1457,16 +1279,16 @@ index 00000000..cb3b3660
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "
Triton
NcclConnector":
+ raise NotImplementedError("Only
Triton
NcclConnector is supported by the
Triton
Connector class")
+ if self.config.kv_connector != "
Dynemo
NcclConnector":
+ raise NotImplementedError("Only
Dynemo
NcclConnector is supported by the
Dynemo
Connector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
+ TritonNcclDataPlane)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing
Triton
NcclConnector under kv_transfer_config %s",
+ "Initializing
Dynemo
NcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
...
...
@@ -1498,7 +1320,7 @@ index 00000000..cb3b3660
+ port_offset=port_offset_base,
+ )
+
+ self.data_plane = TritonNcclDataPlane(
+ self.data_plane = DynemoNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
...
...
@@ -2233,11 +2055,11 @@ index 7aa53d07..f5dd50b7 100644
def close(self):
"""
diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
...
...
@@ -2253,7 +2075,7 @@ index 00000000..8a356504
+logger = logging.getLogger(__name__)
+
+
+class TritonNcclDataPlane:
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
...
...
@@ -2399,7 +2221,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..254337cb 100644
index d82d9ad9..9ba1a326 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
...
...
@@ -2482,13 +2304,11 @@ index d82d9ad9..254337cb 100644
+ self.engine_id = str(uuid.uuid4())
+ self._nixl_agents_names: Optional[List[str]] = None
+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+ self._nixl_agents_names = self._initialize_nixl()
+
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set()
+ self._finished_transfers = set()
+
+ @property
+ def is_nixl_initialized(self) -> bool:
...
...
@@ -2507,6 +2327,8 @@ index d82d9ad9..254337cb 100644
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+ def _initialize_nixl(self) -> List[bytes]:
...
...
@@ -2540,16 +2362,7 @@ index d82d9ad9..254337cb 100644
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -574,6 +624,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +636,7 @@ class LLMEngine:
@@ -584,7 +634,7 @@ class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
...
...
@@ -2558,7 +2371,7 @@ index d82d9ad9..254337cb 100644
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +653,12 @@ class LLMEngine:
@@ -601,8 +651,12 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
...
...
@@ -2572,7 +2385,7 @@ index d82d9ad9..254337cb 100644
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@@ -673,6 +729,7 @@ class LLMEngine:
@@ -673,6 +727,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
...
...
@@ -2580,7 +2393,7 @@ index d82d9ad9..254337cb 100644
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -765,6 +822,7 @@ class LLMEngine:
@@ -765,6 +820,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
...
...
@@ -2588,7 +2401,7 @@ index d82d9ad9..254337cb 100644
)
def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +857,7 @@ class LLMEngine:
@@ -799,6 +855,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
...
...
@@ -2596,7 +2409,7 @@ index d82d9ad9..254337cb 100644
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +888,9 @@ class LLMEngine:
@@ -829,7 +886,9 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
...
...
@@ -2607,7 +2420,7 @@ index d82d9ad9..254337cb 100644
return seq_group
@@ -995,11 +1056,11 @@ class LLMEngine:
@@ -995,11 +1054,11 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
...
...
@@ -2621,7 +2434,7 @@ index d82d9ad9..254337cb 100644
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1386,49 @@ class LLMEngine:
@@ -1325,15 +1384,49 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
...
...
@@ -2641,7 +2454,7 @@ index d82d9ad9..254337cb 100644
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+
+
+ # Separate remote prefill and running seq groups
...
...
@@ -2673,7 +2486,7 @@ index d82d9ad9..254337cb 100644
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1478,31 @@ class LLMEngine:
@@ -1383,9 +1476,29 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
...
...
@@ -2687,11 +2500,9 @@ index d82d9ad9..254337cb 100644
+ req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id,
+ src_block_ids=block_table,
+ staging_block_ids=staging_block_ids,
+ dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
...
...
@@ -2701,13 +2512,13 @@ index d82d9ad9..254337cb 100644
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
-
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +1513,26 @@ class LLMEngine:
@@ -1396,7 +1509,20 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
...
...
@@ -2718,7 +2529,7 @@ index d82d9ad9..254337cb 100644
+ blocks_to_swap_out=[],
+ blocks_to_copy=[])
+
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req)
+
+ for req_id, notif_count in request_notif_counter.items():
...
...
@@ -2726,16 +2537,10 @@ index d82d9ad9..254337cb 100644
+ if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id]
+
+ for req_id, done_count in request_done_counter.items():
+ self._request_done_counter[req_id] += done_count
+ if self._request_done_counter[req_id] > -1:
+ self._finished_transfers.add(req_id)
+ del self._request_done_counter[req_id]
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1456,7 +1592,7 @@ class LLMEngine:
@@ -1456,7 +1582,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
...
...
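One detail worth calling out in the counter bookkeeping above: `_request_notif_counter` entries start at `-tensor_parallel_size` (see the `defaultdict` initialization earlier in this file's hunks), so a request only reaches `_finished_prefills` once every tensor-parallel rank has reported in. The same accounting in miniature:

```python
# Worked sketch of the accounting above: with TP=2 the counter goes
# -2 -> -1 -> 0, and the request is finished once it climbs above -1.
from collections import defaultdict

tp_size = 2
notif_counter = defaultdict(lambda: -tp_size)
finished = set()

for req_id in ["req-A", "req-A"]:      # one notification per TP rank
    notif_counter[req_id] += 1
    if notif_counter[req_id] > -1:     # all ranks have notified
        finished.add(req_id)
        del notif_counter[req_id]

assert finished == {"req-A"}
```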
@@ -2813,7 +2618,7 @@ index 3cf1850e..6b90ece7 100644
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..c501e4c8 100644
index 85b5f31e..d33d546a 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...
...
@@ -2895,7 +2700,7 @@ index 85b5f31e..c501e4c8 100644
+
+ @property
+ def using_nixl_connector(self) -> bool:
+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector"
+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector"
+
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
...
...
@@ -3346,10 +3151,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644
index 00000000..957f55de
index 00000000..03f02006
--- /dev/null
+++ b/vllm/remote_prefill.py
@@ -0,0 +1,54 @@
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
...
...
@@ -3387,7 +3192,6 @@ index 00000000..957f55de
+ """
+ request_id: str
+ src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int]
+ dst_engine_id: str
+ notify_msg: str
...
...
@@ -3531,7 +3335,7 @@ index 12baecde..cbada27f 100644
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "
Triton
NixlConnector":
+ if self.vllm_config.kv_transfer_config.kv_connector == "
Dynemo
NixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
...
...
@@ -3541,13 +3345,13 @@ index 12baecde..cbada27f 100644
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "
Triton
NixlConnector":
+ if self.vllm_config.kv_transfer_config.kv_connector == "
Dynemo
NixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..1b8515bf 100644
index 582aa460..ffb7b403 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
...
...
@@ -3563,7 +3367,7 @@ index 582aa460..1b8515bf 100644
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
+
logger = init_logger(__name__)
...
...
@@ -3577,7 +3381,7 @@ index 582aa460..1b8515bf 100644
+ # TODO ptarasiewicz nixl can also support DRAM
+ assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
+
+ self.nixl_connector = TritonNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector = DynemoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
+ assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
+ self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
+ return self.nixl_connector.agent_name
...
...
@@ -3588,13 +3392,13 @@ index 582aa460..1b8515bf 100644
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
...
...
@@ -3602,8 +3406,8 @@ index 582aa460..1b8515bf 100644
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None:
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
...
...
@@ -3621,12 +3425,11 @@ index 582aa460..1b8515bf 100644
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs],
...
...
@@ -3634,7 +3437,7 @@ index 582aa460..1b8515bf 100644
@torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..ecb68530 100644
index 819b81fb..d9c039eb 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...
...
@@ -3649,7 +3452,7 @@ index 819b81fb..ecb68530 100644
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
logger = init_logger(__name__)
...
...
@@ -3657,17 +3460,16 @@ index 819b81fb..ecb68530 100644
from vllm.platforms import current_platform
self.current_platform = current_platform
+ self.nixl_connector: Optional[TritonNixlConnector] = None
+ self.nixl_connector: Optional[DynemoNixlConnector] = None
+
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,12 @@ class WorkerInput:
@@ -216,6 +220,11 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None
...
...
@@ -3675,31 +3477,29 @@ index 819b81fb..ecb68530 100644
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@@ -232,6 +242,11 @@ class WorkerInput:
@@ -232,6 +241,10 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"),
)
def as_broadcastable_tensor_dict(
@@ -246,6 +261,11 @@ class WorkerInput:
@@ -246,6 +259,10 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg,
}
return tensor_dict
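The three `WorkerInput` hunks above follow one contract: every new transfer field is declared on the dataclass, popped back out in `from_broadcasted_tensor_dict`, and emitted in `as_broadcastable_tensor_dict`. A round-trip sketch of that contract (field names are from the patch; the trimmed-down class is illustrative, not vLLM's real `WorkerInput`):

```python
# Sketch of the broadcast contract maintained above: whatever the dict
# emitter writes, the constructor must pop. The class is a trimmed stand-in.
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class MiniWorkerInput:
    num_steps: int = 1
    src_block_ids: Optional[List[List[int]]] = None
    dst_block_ids: Optional[List[List[int]]] = None
    notify_msg: Optional[List[str]] = None

    def as_broadcastable_tensor_dict(self) -> Dict:
        return {
            "num_steps": self.num_steps,
            "src_block_ids": self.src_block_ids,
            "dst_block_ids": self.dst_block_ids,
            "notify_msg": self.notify_msg,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, d: Dict) -> "MiniWorkerInput":
        return cls(
            num_steps=d.pop("num_steps"),
            src_block_ids=d.pop("src_block_ids"),
            dst_block_ids=d.pop("dst_block_ids"),
            notify_msg=d.pop("notify_msg"),
        )

wi = MiniWorkerInput(src_block_ids=[[0, 1]], dst_block_ids=[[4, 5]], notify_msg=["req-A"])
assert MiniWorkerInput.from_broadcasted_tensor_dict(wi.as_broadcastable_tensor_dict()) == wi
```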
@@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
...
...
@@ -3721,7 +3521,7 @@ index 819b81fb..ecb68530 100644
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
...
...
@@ -3818,7 +3618,7 @@ index 819b81fb..ecb68530 100644
+ else:
+ for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs:
+ for req_ids in notifs.values():
...
...
@@ -3827,20 +3627,12 @@ index 819b81fb..ecb68530 100644
+
+ if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else:
+ request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput]
- return output
+ return output, request_notif_counter, request_done_counter
+ return output, request_notif_counter
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass
...
...
deploy/compoundai/sdk/src/compoundai/cli/serve_nova.py
...
...
@@ -26,7 +26,8 @@ import typing as t
from typing import Any
import click
-from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
+from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
logger = logging.getLogger("compoundai.serve.nova")
...
...
@@ -102,7 +103,7 @@ def main(
    server_context.worker_index = worker_id
    class_instance = service.inner()
-   @triton_worker()
+   @dynemo_worker()
    async def worker(runtime: DistributedRuntime):
        if service_name and service_name != service.name:
            server_context.service_type = "service"
...
...
@@ -157,12 +158,12 @@ def main(
    # Bind an instance of inner to the endpoint
    bound_method = endpoint.func.__get__(class_instance)
    # Only pass request type for now, use Any for response
-   # TODO: Handle a triton_endpoint not having types
+   # TODO: Handle a dynemo_endpoint not having types
    # TODO: Handle multiple endpoints in a single component
-   triton_wrapped_method = triton_endpoint(endpoint.request_type, Any)(bound_method)
+   dynemo_wrapped_method = dynemo_endpoint(endpoint.request_type, Any)(bound_method)
-   result = await td_endpoint.serve_endpoint(triton_wrapped_method)
+   result = await td_endpoint.serve_endpoint(dynemo_wrapped_method)
    # WARNING: unreachable code :( because serve blocks
    logger.info(f"[{run_id}] Result: {result}")
    logger.info(f"[{run_id}] Registered endpoint '{name}'")
...
...
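Stripped of the CompoundAI plumbing, the serve path above reduces to: decorate an async entrypoint with `@dynemo_worker()`, wrap each handler with `dynemo_endpoint(request_type, response_type)`, and hand the result to `serve_endpoint`. A minimal sketch assembled from only the calls visible in this commit (the namespace/component names and `Request` model are illustrative, and any extra component registration the runtime may require is elided):

```python
# Minimal sketch of the serve flow above; only dynemo_worker, dynemo_endpoint,
# namespace/component/endpoint, and serve_endpoint appear in this commit.
# "demo", "echoer", and Request are illustrative names.
import asyncio
from typing import Any, AsyncIterator

import uvloop
from pydantic import BaseModel
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker

class Request(BaseModel):
    text: str

@dynemo_worker()
async def worker(runtime: DistributedRuntime):
    endpoint = runtime.namespace("demo").component("echoer").endpoint("generate")

    @dynemo_endpoint(Request, Any)
    async def generate(request: Request) -> AsyncIterator[str]:
        yield request.text  # echo; a real handler would stream model output

    await endpoint.serve_endpoint(generate)  # blocks for the life of the worker

if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
```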
deploy/compoundai/sdk/src/compoundai/sdk/decorators.py
...
...
@@ -50,7 +50,7 @@ class NovaEndpoint:
        if isinstance(args[1], (str, dict)):
            args[1] = self.request_type.parse_obj(args[1])  # type: ignore
-       # Convert Pydantic model to dict before passing to triton
+       # Convert Pydantic model to dict before passing to dynemo
        if len(args) > 1 and isinstance(args[1], BaseModel):
            args = list(args)  # type: ignore
            args[1] = args[1].model_dump()  # type: ignore
...
...
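The coercion logic above normalizes whatever crosses the endpoint boundary: raw `str`/`dict` payloads are parsed into the endpoint's request model on the way in, and model instances are dumped back to plain dicts before being handed to the runtime. The same round trip in isolation (`EchoRequest` is illustrative; `parse_obj` is the deprecated-but-working pydantic v1 spelling the code keeps, `model_dump` its v2 counterpart):

```python
# Round-trip sketch of the coercion above: wire payload -> model -> dict.
from pydantic import BaseModel

class EchoRequest(BaseModel):
    text: str

payload = {"text": "hi"}                   # raw dict off the wire
request = EchoRequest.parse_obj(payload)   # entry-side parse, as in the code
assert request.model_dump() == payload     # exit-side dump, as in the code
```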
deploy/compoundai/sdk/src/compoundai/sdk/dependency.py
...
...
@@ -72,9 +72,9 @@ class NovaClient:
        else:
            # Create nova worker if no runtime
-           from triton_distributed_rs import DistributedRuntime, triton_worker
+           from dynemo.runtime import DistributedRuntime, dynemo_worker
-           @triton_worker()
+           @dynemo_worker()
            async def stream_worker(runtime: DistributedRuntime):
                try:
                    # Store runtime for future use
...
...
examples/python_rs/llm/tensorrt_llm/README.md
...
...
@@ -90,14 +90,14 @@ Note: NATS and ETCD servers should be running and accessible from the container
Run the server (with debug level logging):
```bash
-TRD_LOG=DEBUG http &
+DYN_LOG=DEBUG http &
```
By default the server will run on port 8080.
Add model to the server:
```bash
-llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.chat/completions
-llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.completions
+llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.chat/completions
+llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.completions
```
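Once a model is registered, the frontend can be smoke-tested with a plain OpenAI-style request (a sketch; this assumes the server's usual OpenAI-compatible `/v1/chat/completions` route, which is not shown in this diff):
```bash
# Hedged example: assumes the standard OpenAI-compatible route on port 8080.
curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 64
  }'
```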
#### 2. Workers
...
...
@@ -214,14 +214,14 @@ Run the container interactively with the following command:
Run the server (with debug level logging):
```bash
-TRD_LOG=DEBUG http &
+DYN_LOG=DEBUG http &
```
By default the server will run on port 8080.
Add model to the server:
```bash
-llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.chat/completions
-llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.completions
+llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.chat/completions
+llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.completions
```
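And the router variant can be checked the same way against the completions route (again assuming the standard OpenAI-compatible `/v1/completions` path, not shown in this diff):
```bash
# Hedged example: assumes the standard OpenAI-compatible route on port 8080.
curl localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "prompt": "The capital of France is",
    "max_tokens": 16
  }'
```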
#### 2. Workers
...
...
examples/python_rs/llm/tensorrt_llm/common/client.py
...
...
@@ -19,12 +19,12 @@ import asyncio
import uvloop
-from triton_distributed.runtime import DistributedRuntime, triton_worker
+from dynemo.runtime import DistributedRuntime, dynemo_worker
from .protocol import Request
-@triton_worker()
+@dynemo_worker()
async def worker(
    runtime: DistributedRuntime,
    component: str,
...
...
@@ -38,7 +38,7 @@ async def worker(
    """
    # create client
    client = (
-       await runtime.namespace("triton-init")
+       await runtime.namespace("dynemo")
        .component(component)
        .endpoint("generate")
        .client()
...
...
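For orientation, the chain above only builds the client handle; the actual call site falls outside this excerpt. One plausible way it is consumed, hedged accordingly (the `round_robin` invocation and the `Request` fields are assumptions for illustration; only the namespace/component/endpoint chain comes from the diff):

```python
# Sketch only: how the client handle built above might be used. The
# round_robin call and Request(prompt=...) shape are illustrative assumptions.
async def query(client, prompt: str) -> None:
    request = Request(prompt=prompt)            # .protocol Request, fields assumed
    stream = await client.round_robin(request.model_dump_json())
    async for chunk in stream:                  # stream responses from a worker
        print(chunk)
```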