"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "abc803e28826b8d2c4817f2c44103e1c1d57a8c2"
Commit 1af7433b authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

refactor: rename triton_distributed to dynemo (#22)


Co-authored-by: default avatarGraham King <grahamk@nvidia.com>
parent ee4ef06b
......@@ -19,7 +19,7 @@
**/*.plan
**/.cache/*
**/*onnx*
# Engine must be allowed because code contains triton_distributed_engine.py
# Engine must be allowed because code contains dynemo_engine.py
**/*tensorrtllm_engines*
**/*tensorrtllm_models*
**/*tensorrtllm_checkpoints*
......
......@@ -22,25 +22,6 @@ on:
jobs:
# icp_validation:
# runs-on: ubuntu-latest
# container:
# image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
# env:
# BUILD_NUMBER: ${{ github.job }}
# CUDA_VISIBLE_DEVICES: -1
# PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
# volumes:
# - ${{ github.workspace }}:/workspace
# permissions:
# contents: read
# packages: read
# steps:
# - uses: actions/checkout@v4
# - run: ./icp/protos/gen_python.sh
# - run: pytest --verbose icp
# timeout-minutes: 3
pre-commit:
runs-on: ubuntu-latest
permissions:
......@@ -52,41 +33,3 @@ jobs:
timeout-minutes: 3
# providers_validation:
# runs-on: ubuntu-latest
# container:
# image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
# env:
# BUILD_NUMBER: ${{ github.job }}
# CUDA_VISIBLE_DEVICES: -1
# PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
# PROTO_OUT: /python/icp/protos
# volumes:
# - ${{ github.workspace }}:/workspace
# permissions:
# contents: read
# packages: read
# steps:
# - uses: actions/checkout@v4
# - run: pytest --verbose providers
# worker_validation:
# runs-on: ubuntu-latest
# container:
# image: ghcr.io/triton-inference-server/triton3/python_ci:0.1.9
# env:
# BUILD_NUMBER: ${{ github.job }}
# CUDA_VISIBLE_DEVICES: -1
# PATH: /opt/tritonserver/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/ucx/bin:/bin:/sbin:/usr/bin:/usr/sbin:/usr/local/bin:/usr/local/mpi/bin:/usr/local/sbin
# PROTO_OUT: /python/icp/protos
# volumes:
# - ${{ github.workspace }}:/workspace
# permissions:
# contents: read
# packages: read
# steps:
# - uses: actions/checkout@v4
# - run: ./icp/protos/gen_python.sh
# - run: pytest -p no:warnings --verbose worker/python/tests
# timeout-minutes: 2
......@@ -17,7 +17,7 @@ limitations under the License.
# Open Source License Attribution
Triton Distributed uses Open Source components. You can find the details of these open-source projects along with license information below.
Dynemo uses Open Source components. You can find the details of these open-source projects along with license information below.
We are grateful to the developers for their contributions to open source and acknowledge these below.
## nats-py - [Apache License 2.0](https://github.com/nats-io/nats.py/blob/main/LICENSE)
......
......@@ -71,7 +71,7 @@ The run script offers a few common workflows:
1. Running a command in a container and exiting.
```
./container/run.sh -- python3 -c "import triton_distributed.runtime; help(triton_distributed.runtime)"
./container/run.sh -- python3 -c "import dynemo.runtime; help(dynemo.runtime)"
```
2. Starting an interactive shell.
......
......@@ -737,6 +737,8 @@ version = "0.1.0"
dependencies = [
"axum 0.6.20",
"clap",
"dynemo-llm",
"dynemo-runtime",
"opentelemetry",
"opentelemetry-prometheus",
"prometheus",
......@@ -747,8 +749,6 @@ dependencies = [
"thiserror 1.0.69",
"tokio",
"tracing",
"triton-distributed-llm",
"triton-distributed-runtime",
]
[[package]]
......@@ -1024,6 +1024,99 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "dynemo-llm"
version = "0.2.1"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
"bindgen",
"blake3",
"bs62",
"bytes",
"chrono",
"cmake",
"derive_builder",
"dynemo-runtime",
"either",
"erased-serde",
"futures",
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"serde_repr",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing",
"unicode-segmentation",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "dynemo-runtime"
version = "0.2.1"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "ed25519"
version = "2.2.3"
......@@ -4232,99 +4325,6 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "triton-distributed-llm"
version = "0.2.1"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"axum 0.8.1",
"bindgen",
"blake3",
"bs62",
"bytes",
"chrono",
"cmake",
"derive_builder",
"either",
"erased-serde",
"futures",
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"serde_repr",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
"tokio-stream",
"tokio-util",
"toktrie",
"toktrie_hf_tokenizers",
"tracing",
"triton-distributed-runtime",
"unicode-segmentation",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "triton-distributed-runtime"
version = "0.2.1"
dependencies = [
"anyhow",
"async-nats",
"async-once-cell",
"async-stream",
"async-trait",
"async_zmq",
"blake3",
"bytes",
"chrono",
"derive-getters",
"derive_builder",
"educe",
"either",
"etcd-client",
"figment",
"futures",
"humantime",
"local-ip-address",
"log",
"nid",
"nix",
"nuid",
"once_cell",
"prometheus",
"rand",
"regex",
"serde",
"serde_json",
"socket2",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
"uuid",
"validator",
"xxhash-rust",
]
[[package]]
name = "try-lock"
version = "0.2.5"
......
......@@ -21,8 +21,8 @@ license = "Apache-2.0"
[dependencies]
# local
triton-distributed-runtime = { path = "../../../lib/runtime" }
triton-distributed-llm = { path = "../../../lib/llm" }
dynemo-runtime = { path = "../../../lib/runtime" }
dynemo-llm = { path = "../../../lib/llm" }
# workspace - todo
......@@ -40,4 +40,4 @@ rand = "0.8"
axum = "0.6"
[dev-dependencies]
reqwest = { version = "0.11", features = ["blocking"] }
\ No newline at end of file
reqwest = { version = "0.11", features = ["blocking"] }
......@@ -8,17 +8,17 @@ the services associated with that endpoint, do some postprocessing on them,
and then publish an event with the postprocessed data.
```bash
# For more details, try TRD_LOG=debug
TRD_LOG=info cargo run --bin count -- --namespace triton-init --component backend --endpoint generate
# For more details, try DYN_LOG=debug
DYN_LOG=info cargo run --bin count -- --namespace dynemo --component backend --endpoint generate
# 2025-02-26T18:45:05.467026Z INFO count: Creating unique instance of Count at triton-init/components/count/instance
# 2025-02-26T18:45:05.472146Z INFO count: Scraping service triton_init_backend_720278f8 and filtering on subject triton_init_backend_720278f8.generate
# 2025-02-26T18:45:05.467026Z INFO count: Creating unique instance of Count at dynemo/components/count/instance
# 2025-02-26T18:45:05.472146Z INFO count: Scraping service dynemo_init_backend_720278f8 and filtering on subject dynemo_init_backend_720278f8.generate
# ...
```
With no matching endpoints running, you should see warnings in the logs:
```bash
2025-02-26T18:45:06.474161Z WARN count: No endpoints found matching subject triton_init_backend_720278f8.generate
2025-02-26T18:45:06.474161Z WARN count: No endpoints found matching subject dynemo_init_backend_720278f8.generate
```
To see metrics published to a matching endpoint, you can use the
......@@ -35,7 +35,7 @@ since the endpoint will automatically get discovered.
When stats are found from the target endpoints being listened on, count will
aggregate and publish some metrics as both an event and to a prometheus web server:
```
2025-02-28T04:05:58.077901Z INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "triton_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
2025-02-28T04:05:58.077901Z INFO count: Aggregated metrics: ProcessedEndpoints { endpoints: [Endpoint { name: "worker-7587884888253033398", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb6", data: ForwardPassMetrics { request_active_slots: 58, request_total_slots: 100, kv_active_blocks: 77, kv_total_blocks: 100 } }, Endpoint { name: "worker-7587884888253033401", subject: "dynemo_init_backend_720278f8.generate-694d951a80e06bb9", data: ForwardPassMetrics { request_active_slots: 71, request_total_slots: 100, kv_active_blocks: 29, kv_total_blocks: 100 } }], worker_ids: [7587884888253033398, 7587884888253033401], load_avg: 53.0, load_std: 24.0 }
```
To see the metrics being published in prometheus format, you can run:
......
......@@ -13,10 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rand::Rng;
use std::sync::Arc;
use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
use triton_distributed_runtime::{
use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
use dynemo_runtime::{
logging,
pipeline::{
async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
......@@ -25,6 +23,8 @@ use triton_distributed_runtime::{
protocols::annotated::Annotated,
stream, DistributedRuntime, Result, Runtime, Worker,
};
use rand::Rng;
use std::sync::Arc;
fn main() -> Result<()> {
logging::init();
......@@ -69,7 +69,7 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
// we must first create a service, then we can attach one or more endpoints
runtime
.namespace("triton-init")?
.namespace("dynemo")?
.component("backend")?
.service_builder()
.create()
......
......@@ -20,13 +20,11 @@ use prometheus::register_gauge_vec;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use triton_distributed_llm::kv_router::protocols::ForwardPassMetrics;
use triton_distributed_llm::kv_router::scheduler::Endpoint;
use triton_distributed_llm::kv_router::scoring::ProcessedEndpoints;
use dynemo_llm::kv_router::protocols::ForwardPassMetrics;
use dynemo_llm::kv_router::scheduler::Endpoint;
use dynemo_llm::kv_router::scoring::ProcessedEndpoints;
use triton_distributed_runtime::{
distributed::Component, service::EndpointInfo, utils::Duration, Result,
};
use dynemo_runtime::{distributed::Component, service::EndpointInfo, utils::Duration, Result};
/// Configuration for LLM worker load capacity metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
......
......@@ -24,7 +24,7 @@
//! - KV Cache Blocks: [Active, Total]
use clap::Parser;
use triton_distributed_runtime::{
use dynemo_runtime::{
error, logging,
traits::events::EventPublisher,
utils::{Duration, Instant},
......@@ -50,7 +50,7 @@ struct Args {
endpoint: String,
/// Namespace to operate in
#[arg(long, env = "TRD_NAMESPACE", default_value = "triton-init")]
#[arg(long, env = "DYN_NAMESPACE", default_value = "dynemo")]
namespace: String,
/// Polling interval in seconds (minimum 1 second)
......@@ -155,7 +155,7 @@ mod tests {
#[test]
fn test_namespace_from_env() {
env::set_var("TRD_NAMESPACE", "test-namespace");
env::set_var("DYN_NAMESPACE", "test-namespace");
let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);
assert_eq!(args.namespace, "test-namespace");
}
......
......@@ -16,7 +16,7 @@
ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
ARG BASE_IMAGE_TAG="25.01-py3"
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynemo
# TODO: non root user by default
......@@ -34,7 +34,7 @@ RUN rustup toolchain install 1.85.0-x86_64-unknown-linux-gnu
# Install OpenAI-compatible frontend and its dependencies from triton server
# repository. These are used to have a consistent interface, schema, and FastAPI
# app between Triton Core and Triton Distributed implementations.
# app between Triton Core and Dynemo implementations.
ARG OPENAI_SERVER_TAG="r25.01"
RUN mkdir -p /opt/tritonserver/python && \
cd /opt/tritonserver/python && \
......@@ -78,7 +78,7 @@ ARG TENSORRTLLM_SKIP_CLONE=
ENV FRAMEWORK=${FRAMEWORK}
RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,target=/tmp/requirements.txt \
--mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ] ; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/triton/llm_binding ; fi ; fi
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; if [ ${TENSORRTLLM_SKIP_CLONE} -ne 1 ] ; then /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} --triton-llm-path /opt/dynemo/llm_binding ; fi ; fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
......@@ -106,7 +106,7 @@ ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_KV_CAPI_PATH="/opt/triton/llm_binding/lib/libtriton_llm_capi.so"
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/llm_binding/lib/libdynemo_llm_capi.so"
ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
......@@ -159,27 +159,27 @@ COPY lib/bindings /workspace/lib/bindings
RUN cd lib/bindings/c/ && \
cargo build --release --locked && cargo doc --no-deps
# Install uv, create virtualenv for general use, and build triton_distributed wheel
# Install uv, create virtualenv for general use, and build dynemo wheel
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12 && \
source /opt/triton/venv/bin/activate && \
RUN mkdir /opt/dynemo && \
uv venv /opt/dynemo/venv --python 3.12 && \
source /opt/dynemo/venv/bin/activate && \
uv build --wheel --out-dir /workspace/dist && \
uv pip install /workspace/dist/triton_distributed*cp312*.whl
uv pip install /workspace/dist/dynemo*cp312*.whl
# Package the bindings
RUN mkdir -p /opt/triton/bindings/wheels && \
mkdir /opt/triton/bindings/lib && \
cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/triton/bindings/.
RUN mkdir -p /opt/dynemo/bindings/wheels && \
mkdir /opt/dynemo/bindings/lib && \
cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/dynemo/bindings/.
# Install triton_distributed_runtime and triton_distributed_llm wheels globally in container for tests that
# Install dynemo.runtime and dynemo.llm wheels globally in container for tests that
# currently run without virtual environment activated.
# TODO: In future, we may use a virtualenv for everything and remove this.
RUN pip install /opt/triton/bindings/wheels/triton_distributed*cp312*.whl
RUN pip install /opt/dynemo/bindings/wheels/dynemo*cp312*.whl
# Copy everything in after install steps to avoid re-running build/install
# Copy everything in after install steps to avoid re-running build/install
# commands on unrelated changes in other dirs.
COPY . /workspace
......
......@@ -24,17 +24,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
RUN mkdir /opt/dynemo && \
uv venv /opt/dynemo/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV VIRTUAL_ENV=/opt/dynemo/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
ARG VLLM_PATCH="vllm_${VLLM_REF}-dynemo-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
......@@ -100,25 +100,25 @@ COPY lib/bindings /workspace/lib/bindings
RUN cd lib/bindings/c && \
cargo build --release --locked && cargo doc --no-deps
# Build triton_distributed wheel
RUN source /opt/triton/venv/bin/activate && \
# Build dynemo wheel
RUN source /opt/dynemo/venv/bin/activate && \
uv build --wheel --out-dir /workspace/dist && \
uv pip install /workspace/dist/triton_distributed*cp312*.whl
uv pip install /workspace/dist/dynemo*cp312*.whl
# Package the bindings
RUN mkdir -p /opt/triton/bindings/wheels && \
mkdir /opt/triton/bindings/lib && \
cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/triton/bindings/.
RUN mkdir -p /opt/dynemo/bindings/wheels && \
mkdir /opt/dynemo/bindings/lib && \
cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/dynemo/bindings/.
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
# FIXME: May want a modification with triton-distributed banner on entry
# FIXME: May want a modification with dynemo-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
......@@ -136,10 +136,10 @@ RUN apt update -y && \
echo "set -g mouse on" >> /root/.tmux.conf
# Set environment variables
ENV VIRTUAL_ENV=/opt/triton/venv
ENV VIRTUAL_ENV=/opt/dynemo/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
# Copy binaries
COPY --from=dev /usr/local/bin/http /usr/local/bin/http
......@@ -166,7 +166,7 @@ COPY examples/python_rs/llm/vllm /workspace/examples/python_rs/llm/vllm
WORKDIR /workspace
# FIXME: May want a modification with triton-distributed banner on entry
# FIXME: May want a modification with dynemo-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
......@@ -150,17 +150,17 @@ ENV PATH=/usr/local/bin/etcd/:$PATH
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
RUN mkdir /opt/dynemo && \
uv venv /opt/dynemo/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV VIRTUAL_ENV=/opt/dynemo/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
ARG VLLM_PATCH="vllm_${VLLM_REF}-dynemo-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
......@@ -225,25 +225,25 @@ COPY lib/bindings /workspace/lib/bindings
RUN cd lib/bindings/c && \
cargo build --release --locked && cargo doc --no-deps
# Build triton_distributed wheel
RUN source /opt/triton/venv/bin/activate && \
# Build dynemo wheel
RUN source /opt/dynemo/venv/bin/activate && \
uv build --wheel --out-dir /workspace/dist && \
uv pip install /workspace/dist/triton_distributed*cp312*.whl
uv pip install /workspace/dist/dynemo*cp312*.whl
# Package the bindings
RUN mkdir -p /opt/triton/bindings/wheels && \
mkdir /opt/triton/bindings/lib && \
cp dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/triton/bindings/.
RUN mkdir -p /opt/dynemo/bindings/wheels && \
mkdir /opt/dynemo/bindings/lib && \
cp dist/dynemo*cp312*.whl /opt/dynemo/bindings/wheels/. && \
cp lib/bindings/c/target/release/libdynemo_llm_capi.so /opt/dynemo/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/dynemo/bindings/.
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# Tell vllm to use the Dynemo LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
# FIXME: May want a modification with triton-distributed banner on entry
# FIXME: May want a modification with dynemo-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
......@@ -261,10 +261,10 @@ RUN apt update -y && \
echo "set -g mouse on" >> /root/.tmux.conf
# Set environment variables
ENV VIRTUAL_ENV=/opt/triton/venv
ENV VIRTUAL_ENV=/opt/dynemo/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
# Copy binaries
COPY --from=dev /usr/local/bin/http /usr/local/bin/http
......@@ -291,7 +291,7 @@ COPY examples/python_rs/llm/vllm_nixl /workspace/examples/python_rs/llm/vllm_nix
WORKDIR /workspace
# FIXME: May want a modification with triton-distributed banner on entry
# FIXME: May want a modification with dynemo-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
......@@ -16,7 +16,7 @@
TENSORRTLLM_BACKEND_REPO_TAG=
TENSORRTLLM_BACKEND_REBUILD=
TRITON_LLM_PATH=
DYNEMO_LLM_PATH=
GIT_TOKEN=
GIT_REPO=
......@@ -43,9 +43,9 @@ get_options() {
missing_requirement $1
fi
;;
--triton-llm-path)
--dynemo-llm-path)
if [ "$2" ]; then
TRITON_LLM_PATH=$2
DYNEMO_LLM_PATH=$2
shift
else
missing_requirement $1
......@@ -147,9 +147,9 @@ if [ ! -z ${TENSORRTLLM_BACKEND_REBUILD} ]; then
# Build the backend
(cd inflight_batcher_llm/src \
&& cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DTRITON_LLM_PATH=$TRITON_LLM_PATH .. \
&& cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DUSE_CXX11_ABI=1 -DDYNEMO_LLM_PATH=$DYNEMO_LLM_PATH .. \
&& make install \
&& cp libtriton_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
&& cp libdynemo_tensorrtllm.so /opt/tritonserver/backends/tensorrtllm/ \
&& cp trtllmExecutorWorker /opt/tritonserver/backends/tensorrtllm/ \
)
fi
......
......@@ -31,7 +31,7 @@ index 9ba49757..a2f88854 100644
f"and `kv_both`")
- if self.kv_connector is not None and self.kv_role is None:
+ if self.kv_connector is not None and self.kv_connector != "TritonNixlConnector" and self.kv_role is None:
+ if self.kv_connector is not None and self.kv_connector != "DynemoNixlConnector" and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
......@@ -44,7 +44,7 @@ index 9ba49757..a2f88854 100644
def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1.
+ if self.kv_connector == "TritonNixlConnector":
+ if self.kv_connector == "DynemoNixlConnector":
+ return False
return self.kv_connector is not None and self.kv_parallel_size > 1
......@@ -277,7 +277,7 @@ index 00000000..350453cd
+logger = logging.getLogger(__name__)
+
+
+class TritonResult:
+class DynemoResult:
+ OK = 0
+ ERR = 1
+
......@@ -290,12 +290,12 @@ index 00000000..350453cd
+
+ try:
+ self.lib = ctypes.CDLL(lib_path)
+ self.lib.triton_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
+ self.lib.triton_llm_init.restype = c_uint32
+ self.lib.dynemo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
+ self.lib.dynemo_llm_init.restype = c_uint32
+
+ result = self.lib.triton_llm_init(namespace.encode(),
+ result = self.lib.dynemo_llm_init(namespace.encode(),
+ component.encode(), worker_id)
+ if result == TritonResult.OK:
+ if result == DynemoResult.OK:
+ logger.info(
+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
+ )
......@@ -306,7 +306,7 @@ index 00000000..350453cd
+ print(f"Failed to load {lib_path}")
+ raise e
+
+ self.lib.triton_kv_event_publish_stored.argtypes = [
+ self.lib.dynemo_kv_event_publish_stored.argtypes = [
+ ctypes.c_uint64, # event_id
+ ctypes.POINTER(ctypes.c_uint32), # token_ids
+ ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
......@@ -315,14 +315,14 @@ index 00000000..350453cd
+ ctypes.POINTER(ctypes.c_uint64), # parent_hash
+ ctypes.c_uint64, # lora_id
+ ]
+ self.lib.triton_kv_event_publish_stored.restype = ctypes.c_uint32 # triton_llm_result_t
+ self.lib.dynemo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynemo_llm_result_t
+
+ self.lib.triton_kv_event_publish_removed.argtypes = [
+ self.lib.dynemo_kv_event_publish_removed.argtypes = [
+ ctypes.c_uint64, # event_id
+ ctypes.POINTER(ctypes.c_uint64), # block_ids
+ ctypes.c_size_t, # num_blocks
+ ]
+ self.lib.triton_kv_event_publish_removed.restype = ctypes.c_uint32 # triton_llm_result_t
+ self.lib.dynemo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynemo_llm_result_t
+
+ self.event_id_counter = 0
+
......@@ -336,7 +336,7 @@ index 00000000..350453cd
+ if parent is not None else None)
+
+ # Publish the event
+ result = self.lib.triton_kv_event_publish_stored(
+ result = self.lib.dynemo_kv_event_publish_stored(
+ self.event_id_counter, # uint64_t event_id
+ token_ids_arr, # const uint32_t *token_ids
+ num_block_tokens, # const uintptr_t *num_block_tokens
......@@ -346,7 +346,7 @@ index 00000000..350453cd
+ 0, # uint64_t lora_id
+ )
+
+ if result == TritonResult.OK:
+ if result == DynemoResult.OK:
+ logger.debug(f"Store - Published KV Event: {block.content_hash}")
+ else:
+ logger.debug(
......@@ -355,28 +355,23 @@ index 00000000..350453cd
+ self.event_id_counter += 1
+
+ def enqueue_removed_event(self, block_hash: PrefixHash):
+ result = self.lib.triton_kv_event_publish_removed(
+ result = self.lib.dynemo_kv_event_publish_removed(
+ self.event_id_counter,
+ (ctypes.c_uint64 * 1)(block_hash),
+ 1,
+ )
+
+ if result == TritonResult.OK:
+ if result == DynemoResult.OK:
+ logger.debug(f"Remove - Published KV Event: {block_hash}")
+ else:
+ logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..abe574d1 100644
index f507847a..ee20d50c 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -4,22 +4,22 @@ import enum
import os
import random
import time
+import copy
from collections import deque
@@ -8,18 +8,17 @@ from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
......@@ -398,7 +393,7 @@ index f507847a..abe574d1 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +325,14 @@ class Scheduler:
@@ -325,12 +324,14 @@ class Scheduler:
def __init__(
self,
......@@ -413,7 +408,7 @@ index f507847a..abe574d1 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +358,7 @@ class Scheduler:
@@ -356,6 +357,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
......@@ -421,7 +416,7 @@ index f507847a..abe574d1 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +374,16 @@ class Scheduler:
@@ -371,6 +373,14 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
......@@ -429,8 +424,6 @@ index f507847a..abe574d1 100644
+ # Sequence groups in the REMOTE_PREFILLING state.
+ # Contain requests that are being prefilled by a remote worker.
+ self.remote_prefilling: Deque[SequenceGroup] = deque()
+ # Contain requests that are being prefilled by a local worker.
+ self.prefill_sending: Deque[SequenceGroup] = deque()
+
+ self._remote_prefill_outputs: Dict[str, int] = {}
+
......@@ -438,25 +431,24 @@ index f507847a..abe574d1 100644
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +514,7 @@ class Scheduler:
@@ -501,7 +511,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +536,8 @@ class Scheduler:
@@ -523,6 +533,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +552,8 @@ class Scheduler:
@@ -537,6 +548,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
......@@ -465,7 +457,7 @@ index f507847a..abe574d1 100644
Returns:
SchedulerRunningOutputs.
@@ -566,6 +583,38 @@ class Scheduler:
@@ -566,6 +579,24 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
......@@ -476,7 +468,6 @@ index f507847a..abe574d1 100644
+ if seq_group.request_id not in finished_prefills:
+ leftover_remote_prefilling_sequences.append(seq_group)
+ continue
+
+ else:
+ finished_prefills.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
......@@ -487,63 +478,39 @@ index f507847a..abe574d1 100644
+ seq.data._stage = SequenceStage.DECODE
+ self.running.appendleft(seq_group)
+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
+ remote_transfers_queue = self.prefill_sending
+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
+ while remote_transfers_queue:
+ seq_group = remote_transfers_queue.popleft()
+ if seq_group.request_id not in finished_transfers:
+ leftover_remote_transfers_sequences.append(seq_group)
+ else:
+ finished_transfers.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
+ seq = seq_group.seqs[0]
+ self.free_seq(seq)
+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
+
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -1008,7 +1057,17 @@ class Scheduler:
@@ -1008,7 +1039,7 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group)
+
+ seq_group_copy = copy.deepcopy(seq_group)
+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
+
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+ self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1107,7 @@ class Scheduler:
@@ -1048,7 +1079,7 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1149,9 @@ class Scheduler:
@@ -1090,7 +1121,8 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
+ enable_chunking=False,
+ finished_prefills=finished_prefills,
+ finished_transfers=finished_transfers)
+ finished_prefills=finished_prefills)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +1167,12 @@ class Scheduler:
@@ -1106,7 +1138,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
......@@ -557,31 +524,30 @@ index f507847a..abe574d1 100644
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1314,14 @@ class Scheduler:
@@ -1248,12 +1285,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
- def _schedule(self) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
+ if finished_prefills or finished_transfers:
+ if finished_prefills:
+ raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
- return self._schedule_default()
+ return self._schedule_default(finished_prefills, finished_transfers)
+ return self._schedule_default(finished_prefills)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +1355,16 @@ class Scheduler:
@@ -1287,14 +1326,15 @@ class Scheduler:
return no_single_seq
def schedule(
- self
+ self,
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
......@@ -590,11 +556,11 @@ index f507847a..abe574d1 100644
- scheduler_outputs: SchedulerOutputs = self._schedule()
+ scheduler_start_time = time.perf_counter()
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1403,8 @@ class Scheduler:
@@ -1333,7 +1373,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
......@@ -604,24 +570,18 @@ index f507847a..abe574d1 100644
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,9 +1435,16 @@ class Scheduler:
@@ -1364,6 +1405,10 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
+ is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
+ logger.debug("Assinged blocks: %s", block_tables)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler:
@@ -1392,6 +1437,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
......@@ -629,7 +589,7 @@ index f507847a..abe574d1 100644
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1569,13 @@ class Scheduler:
@@ -1490,10 +1536,13 @@ class Scheduler:
self._async_stopped.clear()
......@@ -645,80 +605,12 @@ index f507847a..abe574d1 100644
def _append_slots(self,
seq_group: SequenceGroup,
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
--- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def rearrange_kernel(
+ t1_ptr,
+ t2_ptr,
+ N,
+ B,
+ H,
+ C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+ curr_n = offsets // block_size
+ curr_b = offsets // token_size % B
+ curr_h = offsets // C % H
+ curr_c = offsets % C
+
+ src_pos = offsets
+
+ tp_group = curr_h * d // H
+ dst_h = curr_h % (H // d)
+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
+ N, B, H, C = t1.shape
+
+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
+ assert H % d == 0, "H must be divisible by d"
+
+ block_size = B * H * C
+ token_size = H * C
+ tensor_size = N * block_size
+ tensor_subset_size = tensor_size // d
+
+ BLOCK_SIZE = 1024
+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+ rearrange_kernel[grid](
+ t1, t2,
+ N, B, H, C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE=BLOCK_SIZE
+ )
\ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..f1618bc4
index 00000000..bc962726
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,318 @@
@@ -0,0 +1,249 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
......@@ -726,18 +618,39 @@ index 00000000..f1618bc4
+import msgspec
+import time
+import uuid
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+from nixl_wrapper import nixl_wrapper as NixlWrapper
+
+logger = init_logger(__name__)
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ logger.info("NIXL is available")
+except ImportError:
+ logger.warning("NIXL is not available")
+ NixlWrapper = None
+
+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
+ logger.info("Initializing patched NixlWrapper")
+ import nixl_bindings as nixl
+ # Read available backends and device info from nixl_config
+ # For now setting the multithreading to enabled.
+ devices = nixl.nixlAgentConfig(False)
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ print("Initializied NIXL agent:", agent_name)
+
+NixlWrapper.__init__ = nixl_wrapper_init_patch
+
+
+
+class NixlMetadata(
+ msgspec.Struct,
......@@ -749,20 +662,14 @@ index 00000000..f1618bc4
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
+
+
+class TritonNixlConnector:
+class DynemoNixlConnector:
+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
+ self.vllm_config = vllm_config
+ if NixlWrapper is None:
+ logger.error("NIXL is not available")
+ raise RuntimeError("NIXL is not available")
+ logger.info("Initializing NIXL wrapper")
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
+ self.block_len = None
+ self.kv_caches = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
......@@ -771,51 +678,33 @@ index 00000000..f1618bc4
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ caches_data = []
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ region_len = cache.numel() * cache.element_size()
+ gpu_id = cache.get_device()
+ assert gpu_id > -1, "Tensor is not on GPU"
+ caches_data.append((base_addr, region_len, gpu_id))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
......@@ -825,14 +714,10 @@ index 00000000..f1618bc4
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+ def add_remote_agent(self, engine_id, agent_metadata):
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+ self._remote_agents[engine_id] = agent_name
+ return agent_name
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
......@@ -847,29 +732,17 @@ index 00000000..f1618bc4
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+ if rank is None:
+ rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ def _get_range_descs(self, engine_id, ranges, layer_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+ start_offset = range_start * self.block_len
+ blocks_len = (range_end - range_start + 1) * self.block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
......@@ -882,9 +755,9 @@ index 00000000..f1618bc4
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+ ranges.append([sorted_block_ids[i]])
+ else:
+ ranges[-1][1] = sorted_block_ids[i]
+ ranges[-1].append(sorted_block_ids[i])
+ return ranges
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
......@@ -924,24 +797,11 @@ index 00000000..f1618bc4
+ src_idx += 1
+
+ return src_overlapping_ranges, dst_overlapping_ranges
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ return descs_ids
+
+
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+ def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
......@@ -950,62 +810,44 @@ index 00000000..f1618bc4
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+ logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ for i in range(tp_multiplier):
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
......@@ -1018,26 +860,6 @@ index 00000000..f1618bc4
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+ def get_done_tranfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
+ running_reqs = []
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
+ else:
+ raise RuntimeError("Transfer failed with state %s", xfer_state)
+ if len(running_reqs) == 0:
+ done_req_ids.append(req_id)
+ else:
+ self._transfers[req_id] = running_reqs
+ return done_req_ids
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
......@@ -1064,9 +886,9 @@ index fe480533..61a357d0 100644
"SimpleConnector")
+
+KVConnectorFactory.register_connector(
+ "TritonNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.triton_connector",
+ "TritonConnector")
+ "DynemoNcclConnector",
+ "vllm.distributed.kv_transfer.kv_connector.dynemo_connector",
+ "DynemoConnector")
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..e33919c1 100644
......@@ -1396,7 +1218,7 @@ index 2033e976..e33919c1 100644
+ world_group.broadcast_object(kv_config_enhanced)
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch")
+ raise NotImplementedError("MooncakeConnector is not supported in Dynemo Distributed vllm patch")
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
......@@ -1407,11 +1229,11 @@ index 2033e976..e33919c1 100644
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_connector/triton_connector.py b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
diff --git a/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
new file mode 100644
index 00000000..cb3b3660
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/triton_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/dynemo_connector.py
@@ -0,0 +1,350 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
......@@ -1443,7 +1265,7 @@ index 00000000..cb3b3660
+logger = init_logger(__name__)
+
+
+class TritonConnector(KVConnectorBase):
+class DynemoConnector(KVConnectorBase):
+
+ def __init__(
+ self,
......@@ -1457,16 +1279,16 @@ index 00000000..cb3b3660
+ self.tp_size = config.parallel_config.tensor_parallel_size
+ self.rank = rank
+
+ if self.config.kv_connector != "TritonNcclConnector":
+ raise NotImplementedError("Only TritonNcclConnector is supported by the TritonConnector class")
+ if self.config.kv_connector != "DynemoNcclConnector":
+ raise NotImplementedError("Only DynemoNcclConnector is supported by the DynemoConnector class")
+
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
+ PyNcclPipe)
+ from vllm.distributed.kv_transfer.kv_pipe.triton_nccl_pipe import (
+ TritonNcclDataPlane)
+ from vllm.distributed.kv_transfer.kv_pipe.dynemo_nccl_pipe import (
+ DynemoNcclDataPlane)
+
+ logger.info(
+ "Initializing TritonNcclConnector under kv_transfer_config %s",
+ "Initializing DynemoNcclConnector under kv_transfer_config %s",
+ self.config)
+
+ self.lookup_buffer_size = self.config.kv_buffer_size
......@@ -1498,7 +1320,7 @@ index 00000000..cb3b3660
+ port_offset=port_offset_base,
+ )
+
+ self.data_plane = TritonNcclDataPlane(
+ self.data_plane = DynemoNcclDataPlane(
+ data_pipe=self.data_pipe,
+ port=self._get_data_plane_port(self.global_kv_rank),
+ )
......@@ -2233,11 +2055,11 @@ index 7aa53d07..f5dd50b7 100644
def close(self):
"""
diff --git a/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
new file mode 100644
index 00000000..8a356504
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_pipe/triton_nccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/dynemo_nccl_pipe.py
@@ -0,0 +1,124 @@
+import logging
+import threading
......@@ -2253,7 +2075,7 @@ index 00000000..8a356504
+logger = logging.getLogger(__name__)
+
+
+class TritonNcclDataPlane:
+class DynemoNcclDataPlane:
+ def __init__(
+ self,
+ data_pipe: PyNcclPipe,
......@@ -2399,7 +2221,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..254337cb 100644
index d82d9ad9..9ba1a326 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
......@@ -2482,13 +2304,11 @@ index d82d9ad9..254337cb 100644
+ self.engine_id = str(uuid.uuid4())
+ self._nixl_agents_names: Optional[List[str]] = None
+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+ self._nixl_agents_names = self._initialize_nixl()
+
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set()
+ self._finished_transfers = set()
+
+ @property
+ def is_nixl_initialized(self) -> bool:
......@@ -2507,6 +2327,8 @@ index d82d9ad9..254337cb 100644
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+ def _initialize_nixl(self) -> List[bytes]:
......@@ -2540,16 +2362,7 @@ index d82d9ad9..254337cb 100644
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -574,6 +624,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +636,7 @@ class LLMEngine:
@@ -584,7 +634,7 @@ class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
......@@ -2558,7 +2371,7 @@ index d82d9ad9..254337cb 100644
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +653,12 @@ class LLMEngine:
@@ -601,8 +651,12 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
......@@ -2572,7 +2385,7 @@ index d82d9ad9..254337cb 100644
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@@ -673,6 +729,7 @@ class LLMEngine:
@@ -673,6 +727,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -2580,7 +2393,7 @@ index d82d9ad9..254337cb 100644
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -765,6 +822,7 @@ class LLMEngine:
@@ -765,6 +820,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
......@@ -2588,7 +2401,7 @@ index d82d9ad9..254337cb 100644
)
def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +857,7 @@ class LLMEngine:
@@ -799,6 +855,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
......@@ -2596,7 +2409,7 @@ index d82d9ad9..254337cb 100644
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +888,9 @@ class LLMEngine:
@@ -829,7 +886,9 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
......@@ -2607,7 +2420,7 @@ index d82d9ad9..254337cb 100644
return seq_group
@@ -995,11 +1056,11 @@ class LLMEngine:
@@ -995,11 +1054,11 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
......@@ -2621,7 +2434,7 @@ index d82d9ad9..254337cb 100644
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1386,49 @@ class LLMEngine:
@@ -1325,15 +1384,49 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
......@@ -2641,7 +2454,7 @@ index d82d9ad9..254337cb 100644
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+
+
+ # Separate remote prefill and running seq groups
......@@ -2673,7 +2486,7 @@ index d82d9ad9..254337cb 100644
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1478,31 @@ class LLMEngine:
@@ -1383,9 +1476,29 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
......@@ -2687,11 +2500,9 @@ index d82d9ad9..254337cb 100644
+ req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id,
+ src_block_ids=block_table,
+ staging_block_ids=staging_block_ids,
+ dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
......@@ -2701,13 +2512,13 @@ index d82d9ad9..254337cb 100644
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
-
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +1513,26 @@ class LLMEngine:
@@ -1396,7 +1509,20 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
......@@ -2718,7 +2529,7 @@ index d82d9ad9..254337cb 100644
+ blocks_to_swap_out=[],
+ blocks_to_copy=[])
+
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req)
+
+ for req_id, notif_count in request_notif_counter.items():
......@@ -2726,16 +2537,10 @@ index d82d9ad9..254337cb 100644
+ if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id]
+
+ for req_id, done_count in request_done_counter.items():
+ self._request_done_counter[req_id] += done_count
+ if self._request_done_counter[req_id] > -1:
+ self._finished_transfers.add(req_id)
+ del self._request_done_counter[req_id]
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1456,7 +1592,7 @@ class LLMEngine:
@@ -1456,7 +1582,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
......@@ -2813,7 +2618,7 @@ index 3cf1850e..6b90ece7 100644
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..c501e4c8 100644
index 85b5f31e..d33d546a 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
......@@ -2895,7 +2700,7 @@ index 85b5f31e..c501e4c8 100644
+
+ @property
+ def using_nixl_connector(self) -> bool:
+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector"
+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector"
+
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
......@@ -2941,7 +2746,7 @@ index 85b5f31e..c501e4c8 100644
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+
+ logger.debug("Metircs successful.")
+ logger.debug("Metircs successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
......@@ -3346,10 +3151,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644
index 00000000..957f55de
index 00000000..03f02006
--- /dev/null
+++ b/vllm/remote_prefill.py
@@ -0,0 +1,54 @@
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
......@@ -3387,7 +3192,6 @@ index 00000000..957f55de
+ """
+ request_id: str
+ src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int]
+ dst_engine_id: str
+ notify_msg: str
......@@ -3531,7 +3335,7 @@ index 12baecde..cbada27f 100644
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ if self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
......@@ -3541,13 +3345,13 @@ index 12baecde..cbada27f 100644
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ if self.vllm_config.kv_transfer_config.kv_connector == "DynemoNixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..1b8515bf 100644
index 582aa460..ffb7b403 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
......@@ -3563,7 +3367,7 @@ index 582aa460..1b8515bf 100644
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
+
logger = init_logger(__name__)
......@@ -3577,7 +3381,7 @@ index 582aa460..1b8515bf 100644
+ # TODO ptarasiewicz nixl can also support DRAM
+ assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
+
+ self.nixl_connector = TritonNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector = DynemoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
+ assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
+ self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
+ return self.nixl_connector.agent_name
......@@ -3588,13 +3392,13 @@ index 582aa460..1b8515bf 100644
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3602,8 +3406,8 @@ index 582aa460..1b8515bf 100644
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None:
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3621,12 +3425,11 @@ index 582aa460..1b8515bf 100644
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs],
......@@ -3634,7 +3437,7 @@ index 582aa460..1b8515bf 100644
@torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..ecb68530 100644
index 819b81fb..d9c039eb 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
......@@ -3649,7 +3452,7 @@ index 819b81fb..ecb68530 100644
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
+from vllm.distributed.device_communicators.nixl import DynemoNixlConnector
logger = init_logger(__name__)
......@@ -3657,17 +3460,16 @@ index 819b81fb..ecb68530 100644
from vllm.platforms import current_platform
self.current_platform = current_platform
+ self.nixl_connector: Optional[TritonNixlConnector] = None
+ self.nixl_connector: Optional[DynemoNixlConnector] = None
+
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,12 @@ class WorkerInput:
@@ -216,6 +220,11 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None
......@@ -3675,31 +3477,29 @@ index 819b81fb..ecb68530 100644
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@@ -232,6 +242,11 @@ class WorkerInput:
@@ -232,6 +241,10 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"),
)
def as_broadcastable_tensor_dict(
@@ -246,6 +261,11 @@ class WorkerInput:
@@ -246,6 +259,10 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg,
}
return tensor_dict
@@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
......@@ -3721,7 +3521,7 @@ index 819b81fb..ecb68530 100644
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
......@@ -3818,7 +3618,7 @@ index 819b81fb..ecb68530 100644
+ else:
+ for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs:
+ for req_ids in notifs.values():
......@@ -3827,20 +3627,12 @@ index 819b81fb..ecb68530 100644
+
+ if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else:
+ request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput]
- return output
+ return output, request_notif_counter, request_done_counter
+ return output, request_notif_counter
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass
......
......@@ -26,7 +26,8 @@ import typing as t
from typing import Any
import click
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
logger = logging.getLogger("compoundai.serve.nova")
......@@ -102,7 +103,7 @@ def main(
server_context.worker_index = worker_id
class_instance = service.inner()
@triton_worker()
@dynemo_worker()
async def worker(runtime: DistributedRuntime):
if service_name and service_name != service.name:
server_context.service_type = "service"
......@@ -157,12 +158,12 @@ def main(
# Bind an instance of inner to the endpoint
bound_method = endpoint.func.__get__(class_instance)
# Only pass request type for now, use Any for response
# TODO: Handle a triton_endpoint not having types
# TODO: Handle a dynemo_endpoint not having types
# TODO: Handle multiple endpoints in a single component
triton_wrapped_method = triton_endpoint(endpoint.request_type, Any)(
dynemo_wrapped_method = dynemo_endpoint(endpoint.request_type, Any)(
bound_method
)
result = await td_endpoint.serve_endpoint(triton_wrapped_method)
result = await td_endpoint.serve_endpoint(dynemo_wrapped_method)
# WARNING: unreachable code :( because serve blocks
logger.info(f"[{run_id}] Result: {result}")
logger.info(f"[{run_id}] Registered endpoint '{name}'")
......
......@@ -50,7 +50,7 @@ class NovaEndpoint:
if isinstance(args[1], (str, dict)):
args[1] = self.request_type.parse_obj(args[1]) # type: ignore
# Convert Pydantic model to dict before passing to triton
# Convert Pydantic model to dict before passing to dynemo
if len(args) > 1 and isinstance(args[1], BaseModel):
args = list(args) # type: ignore
args[1] = args[1].model_dump() # type: ignore
......
......@@ -72,9 +72,9 @@ class NovaClient:
else:
# Create nova worker if no runtime
from triton_distributed_rs import DistributedRuntime, triton_worker
from dynemo.runtime import DistributedRuntime, dynemo_worker
@triton_worker()
@dynemo_worker()
async def stream_worker(runtime: DistributedRuntime):
try:
# Store runtime for future use
......
......@@ -90,14 +90,14 @@ Note: NATS and ETCD servers should be running and accessible from the container
Run the server (with debug-level logging enabled):
```bash
TRD_LOG=DEBUG http &
DYN_LOG=DEBUG http &
```
By default the server will run on port 8080.
Add model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.tensorrt-llm.completions
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.tensorrt-llm.completions
```
#### 2. Workers
......@@ -214,14 +214,14 @@ Run the container interactively with the following command:
Run the server (with debug-level logging enabled):
```bash
TRD_LOG=DEBUG http &
DYN_LOG=DEBUG http &
```
By default the server will run on port 8080.
Add model to the server:
```bash
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 triton-init.router.completions
llmctl http add chat TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.chat/completions
llmctl http add completion TinyLlama/TinyLlama-1.1B-Chat-v1.0 dynemo.router.completions
```
#### 2. Workers
......
......@@ -19,12 +19,12 @@ import asyncio
import uvloop
from triton_distributed.runtime import DistributedRuntime, triton_worker
from dynemo.runtime import DistributedRuntime, dynemo_worker
from .protocol import Request
@triton_worker()
@dynemo_worker()
async def worker(
runtime: DistributedRuntime,
component: str,
......@@ -38,7 +38,7 @@ async def worker(
"""
# create client
client = (
await runtime.namespace("triton-init")
await runtime.namespace("dynemo")
.component(component)
.endpoint("generate")
.client()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment