Commit f6d723f6 authored by Blazej, committed by GitHub

feat: Add vLLM workers to LLM example (#41)

Add example of LLM disaggregated serving
parent 8980ec37
@@ -19,7 +19,8 @@
**/*.plan
**/.cache/*
**/*onnx*
# Engine must be allowed because code contains triton_distributed_engine.py
#**/*engine*
**/*pytorch_model*
**/*.pth*
**/*.pt
......
@@ -100,6 +100,21 @@ ENV TENSORRTLLM_BACKEND_COMMIT=$TENSORRTLLM_BACKEND_COMMIT
# TODO set VLLM Version
# ENV VLLM_VERSION
ARG VLLM_FRAMEWORK
# DEFAULT VLLM VARIABLES
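# NOTE: ${VAR:+value} expands to "value" only when VLLM_FRAMEWORK is set
# (i.e. when building with --framework vllm), so these defaults stay empty
# for builds of other frameworks.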
ENV VLLM_ATTENTION_BACKEND=${VLLM_FRAMEWORK:+FLASHINFER}
ENV VLLM_WORKER_MULTIPROC_METHOD=${VLLM_FRAMEWORK:+spawn}
ENV VLLM_TORCH_HOST=${VLLM_FRAMEWORK:+localhost}
ENV VLLM_TORCH_PORT=${VLLM_FRAMEWORK:+36183}
ENV VLLM_DATA_PLANE_BACKEND=${VLLM_FRAMEWORK:+nccl}
ENV VLLM_BASELINE_WORKERS=${VLLM_FRAMEWORK:+0}
ENV VLLM_CONTEXT_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_LOGGING_LEVEL=${VLLM_FRAMEWORK:+INFO}
ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
......
@@ -15,7 +15,10 @@
# Necessary for vLLM engine.
--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4
flashinfer<0.2.0
# Necessary for vLLM engine.
ninja==1.11.1.3
ucx-py-cu12
# vLLM is installed by patching script
# vllm==0.6.3post1
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -49,6 +49,9 @@ For more details on the basics of Triton Distributed, please see the [Hello Worl
- For FP8 usage, GPUs with **Compute Capability >= 8.9** are required. (A quick capability check is sketched after this list.)
- If you have older GPUs, consider BF16/FP16 precision variants instead of `FP8`. (See [below](#model-precision-variants).)
5. **HuggingFace**
   - You need a HuggingFace account to download the model, and you must set the `HF_TOKEN` environment variable.
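If you are unsure of your GPU's compute capability, a quick check (a sketch assuming PyTorch is available, e.g. inside the built container) is:

```bash
# Prints a tuple such as (8, 9); FP8 requires >= (8, 9).
python3 -c "import torch; print(torch.cuda.get_device_capability())"
```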
---
## 2. Building the Environment
@@ -56,7 +59,7 @@ For more details on the basics of Triton Distributed, please see the [Hello Worl
The example is designed to run in a containerized environment using Triton Distributed, vLLM, and associated dependencies. To build the container:
```bash
./container/build.sh --framework vllm
```
This command pulls necessary dependencies and patches vLLM in the container image.
@@ -67,45 +70,58 @@ This command pulls necessary dependencies and patches vLLM in the container imag
Below is a minimal example of how to start each component of a disaggregated serving setup. The typical sequence is:
1. **Start the Context Worker(s) and Request Plane**
2. **Start the Generate Worker(s)**
3. **Start the API Server** (handles incoming requests and coordinates workers)
All components must be able to connect to the same request plane to coordinate.
### 3.1 HuggingFace Token
Export your HuggingFace token so the model can be downloaded:
```bash
export HF_TOKEN=<YOUR TOKEN>
```
### 3.2 Launch Interactive Environment
```bash
./container/run.sh --framework vllm -it
```
Note: all subsequent commands will be run in the same container for simplicity.

Note: by default this command makes all GPU devices visible. Use the flag
```
--gpus
```
to selectively make GPU devices visible, as shown in the sketch below.
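For example, to expose only GPUs 0 and 1 (the exact value format is an assumption based on Docker's `--gpus` syntax; adjust to whatever `run.sh` expects):

```bash
./container/run.sh --framework vllm -it --gpus '"device=0,1"'
```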
### 3.3 Launch Context Worker and Request Plane
The context stage encodes incoming prompts. By default, vLLM uses GPU resources to tokenize and prepare the model’s key-value (KV) caches.
Within the container start the context worker and the request plane:
```
CUDA_VISIBLE_DEVICES=0 \
VLLM_WORKER_ID=0 \
python3 -m llm.vllm.deploy \
--context-worker-count 1 \
--request-plane-uri ${HOSTNAME}:4223 \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--worker-name llama \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 3500 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--context-tp-size 1 \
--generate-tp-size 1 \
--initialize-request-plane &
```
**Key flags**:
@@ -113,50 +129,115 @@ python3 -m examples.vllm.deploy \
- `--kv-cache-dtype fp8`: Using FP8 for caching (requires CC >= 8.9).
- `CUDA_VISIBLE_DEVICES=0`: Binds worker to GPU `0`.
#### Expected Output
```
<SNIP>
Workers started ... press Ctrl-C to Exit
[168] 2025/01/24 09:17:38.879908 [INF] Starting nats-server
[168] 2025/01/24 09:17:38.879982 [INF] Version: 2.10.24
[168] 2025/01/24 09:17:38.879987 [INF] Git: [1d6f7ea]
[168] 2025/01/24 09:17:38.879989 [INF] Name: NDBCCXARM6D2BMMRJOKZCJD4TGVXXPCJKQRXALJOPHLA5W7ISCW4VHU5
[168] 2025/01/24 09:17:38.879992 [INF] Node: S4g51H7K
[168] 2025/01/24 09:17:38.879995 [INF] ID: NDBCCXARM6D2BMMRJOKZCJD4TGVXXPCJKQRXALJOPHLA5W7ISCW4VHU5
[168] 2025/01/24 09:17:38.880339 [INF] Starting JetStream
<SNIP>
INFO 01-24 09:17:49 parallel_state.py:942] Stage: PREFILL
```
### 3.4 Launch Generate (Decode) Worker
The generate stage consumes the KV cache produced in the context step and generates output tokens.
Within the container start the generate worker:
```bash
CUDA_VISIBLE_DEVICES=1 \
VLLM_WORKER_ID=1 \
python3 -m llm.vllm.deploy \
--generate-worker-count 1 \
--request-plane-uri ${HOSTNAME}:4223 \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--worker-name llama \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 3500 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--context-tp-size 1 \
--generate-tp-size 1 &
```
> [!NOTE]
> - The first run in a newly launched container includes the model
>   download. Wait until you see the llama handler started before
>   sending requests.
**Key flags**:
- `--generate-worker-count`: Launches decode worker(s).
- `CUDA_VISIBLE_DEVICES=1`: Binds worker to GPU `1`.
#### Expected Output
```
<SNIP>
model-00002-of-00002.safetensors: 100% 4.08G/4.08G [01:36<00:00, 42.2MB/s]
model-00001-of-00002.safetensors: 100% 4.71G/5.00G [01:51<00:06, 41.9MB/s]
<SNIP>
INFO 01-24 09:21:22 model_runner.py:1406] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 01-24 09:21:22 model_runner.py:1410] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
<SNIP>
09:22:10 worker.py:266[Triton Worker] INFO: Worker started...
09:22:10 worker.py:241[Triton Worker] INFO: Starting generate handler...
09:22:10 worker.py:266[Triton Worker] INFO: Worker started...
09:22:10 worker.py:241[Triton Worker] INFO: Starting llama handler...
```
> [!NOTE]
> - You can run multiple prefill and decode workers for higher throughput. (A sketch of adding a second generate worker follows.)
> - For large models, ensure you have enough GPU memory (or GPUs).
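For instance, an additional generate worker could be launched on a third GPU by reusing the command above with a different device and worker id (a sketch, not a verified configuration for this example):

```bash
CUDA_VISIBLE_DEVICES=2 \
VLLM_WORKER_ID=2 \
python3 -m llm.vllm.deploy \
--generate-worker-count 1 \
--request-plane-uri ${HOSTNAME}:4223 \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--worker-name llama \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 3500 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--context-tp-size 1 \
--generate-tp-size 1 &
```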
### 3.5 API Server
The API server in a vLLM-disaggregated setup listens for OpenAI-compatible requests on a chosen port (default 8005). Below is an example command:
```bash
python3 -m llm.api_server \
--tokenizer neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--request-plane-uri ${HOSTNAME}:4223 \
--api-server-host ${HOSTNAME} \
--model-name llama \
--api-server-port 8005 &
```
#### Expected Output
```
[WARNING] Adding CORS for the following origins: ['http://localhost']
INFO: Started server process [498]
INFO: Waiting for application startup.
TRACE: ASGI [1] Started scope={'type': 'lifespan', 'asgi': {'version': '3.0', 'spec_version': '2.0'}, 'state': {}}
TRACE: ASGI [1] Receive {'type': 'lifespan.startup'}
TRACE: ASGI [1] Send {'type': 'lifespan.startup.complete'}
INFO: Application startup complete.
INFO: Uvicorn running on http://2u2g-gen-0349:8005 (Press CTRL+C to quit)
```
## 4. Sending Requests
Once the API server is running (by default on `localhost:8005`), you can send OpenAI-compatible requests. For example:
```bash
curl ${HOSTNAME}:8005/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
],
"temperature": 0,
"top_p": 0.95,
@@ -171,6 +252,25 @@ curl localhost:8005/v1/chat/completions \
The above request will return a streamed response with the model’s answer.
#### Expected Output
```
INFO 01-24 09:33:05 async_llm_engine.py:207] Added request 052eabe0-fc54-4f7c-9be8-4926523b26fc___0.
INFO 01-24 09:33:05 kv_cache.py:378] Fetching source address for worker 0 by key worker_0_rank_0
TRACE: 127.0.0.1:49878 - ASGI [2] Send {'type': 'http.response.body', 'body': '<290 bytes>', 'more_body': True}
data: {"id":"052eabe0-fc54-4f7c-9be8-4926523b26fc","choices":[{"delta":{"content":"\n\n","role":"assistant"},"logprobs":null,"finish_reason":null,"index":0}],"created":1737711185,"model":"llama","system_fingerprint":"052eabe0-fc54-4f7c-9be8-4926523b26fc","object":"chat.completion.chunk"}
INFO 01-24 09:33:05 async_llm_engine.py:175] Finished request 052eabe0-fc54-4f7c-9be8-4926523b26fc___0.
TRACE: 127.0.0.1:49878 - ASGI [2] Send {'type': 'http.response.body', 'body': '<317 bytes>', 'more_body': True}
TRACE: 127.0.0.1:49878 - ASGI [2] Send {'type': 'http.response.body', 'body': '<14 bytes>', 'more_body': True}
TRACE: 127.0.0.1:49878 - ASGI [2] Send {'type': 'http.response.body', 'body': '<0 bytes>', 'more_body': False}
data: {"id":"052eabe0-fc54-4f7c-9be8-4926523b26fc","choices":[{"delta":{"content":"The capital of France is Paris.","role":"assistant"},"logprobs":null,"finish_reason":null,"index":0}],"created":1737711185,"model":"llama","system_fingerprint":"052eabe0-fc54-4f7c-9be8-4926523b26fc","object":"chat.completion.chunk"}
TRACE: 127.0.0.1:49878 - ASGI [2] Receive {'type': 'http.disconnect'}
data: [DONE]
```
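Since the endpoint is OpenAI-compatible, the standard `openai` Python client can also be used. A minimal sketch, assuming the `openai` package is installed and that the server implements enough of the chat-completions protocol for the client's streaming mode:

```python
from openai import OpenAI

# Point the standard OpenAI client at the local API server; the key is unused.
client = OpenAI(base_url="http://localhost:8005/v1", api_key="not-used")

stream = client.chat.completions.create(
    model="llama",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    temperature=0,
    max_tokens=25,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```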
## 5. Benchmarking
You can benchmark this setup using [**GenAI-Perf**](https://github.com/triton-inference-server/perf_analyzer/blob/main/genai-perf/README.md), which supports OpenAI endpoints for chat or completion requests.
@@ -178,7 +278,7 @@ You can benchmark this setup using [**GenAI-Perf**](https://github.com/triton-in
```bash
genai-perf profile \
-m llama \
--url ${HOSTNAME}:8005 \
--endpoint-type chat \
--streaming \
--num-dataset-entries 1000 \
@@ -189,8 +289,8 @@ genai-perf profile \
--synthetic-input-tokens-stddev 0 \
--output-tokens-stddev 0 \
--tokenizer neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8 \
--synthetic-input-tokens-mean 300 \
--output-tokens-mean 3000 \
--extra-inputs seed:100 \
--extra-inputs min_tokens:150 \
--extra-inputs max_tokens:150 \
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import signal
import sys
import time
from pathlib import Path
from llm.vllm.operators.vllm import (
VllmBaselineOperator,
VllmContextOperator,
VllmGenerateOperator,
)
from triton_distributed.worker import Deployment, OperatorConfig, WorkerConfig
from .parser import parse_args
deployment = None
def handler(signum, frame):
exit_code = 0
if deployment:
print("Stopping Workers")
exit_code = deployment.stop()
print(f"Workers Stopped Exit Code {exit_code}")
sys.exit(exit_code)
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for sig in signals:
try:
signal.signal(sig, handler)
except Exception:
pass
def _create_context_op(name, args, max_inflight_requests):
return OperatorConfig(
name=name,
implementation=VllmContextOperator,
max_inflight_requests=int(max_inflight_requests),
parameters=vars(args),
)
def _create_generate_op(name, args, max_inflight_requests):
return OperatorConfig(
name=name,
implementation=VllmGenerateOperator,
max_inflight_requests=int(max_inflight_requests),
parameters=vars(args),
)
def _create_baseline_op(name, args, max_inflight_requests):
return OperatorConfig(
name=name,
implementation=VllmBaselineOperator,
max_inflight_requests=int(max_inflight_requests),
parameters=vars(args),
)
def main(args):
global deployment
if args.log_dir:
log_dir = Path(args.log_dir)
log_dir.mkdir(exist_ok=True)
worker_configs = []
# Context/Generate workers used for Disaggregated Serving
if args.context_worker_count == 1:
context_op = _create_context_op(args.worker_name, args, 1000)
context = WorkerConfig(
operators=[context_op],
# Context worker gets --worker-name as it is the model that will
# be hit first in a disaggregated setting.
name=args.worker_name,
)
worker_configs.append((context, 1))
if args.generate_worker_count == 1:
generate_op = _create_generate_op("generate", args, 1000)
generate = WorkerConfig(
operators=[generate_op],
# Generate worker gets a hard-coded name "generate" as the context
# worker will talk directly to it.
name="generate",
)
worker_configs.append((generate, 1))
# NOTE: Launching baseline worker and context/generate workers at
# the same time is not currently supported.
if args.baseline_worker_count == 1:
# Baseline worker has a hard-coded name just for testing purposes
baseline_op = _create_baseline_op("baseline", args, 1000)
baseline = WorkerConfig(
operators=[baseline_op],
name="baseline",
)
worker_configs.append((baseline, 1))
deployment = Deployment(
worker_configs,
initialize_request_plane=args.initialize_request_plane,
log_dir=args.log_dir,
log_level=args.log_level,
starting_metrics_port=args.starting_metrics_port,
request_plane_args=([], {"request_plane_uri": args.request_plane_uri}),
)
deployment.start()
print("Workers started ... press Ctrl-C to Exit")
while True:
time.sleep(10)
if __name__ == "__main__":
args = parse_args()
main(args)
#!/bin/bash
# FIXME: Convert this script to README steps
export VLLM_ATTENTION_BACKEND=FLASHINFER
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TORCH_HOST=localhost
export VLLM_TORCH_PORT=36183
export VLLM_BASELINE_WORKERS=1
export VLLM_BASELINE_TP_SIZE=1
export VLLM_LOGGING_LEVEL=INFO
export VLLM_DATA_PLANE_BACKEND=nccl
export PYTHONUNBUFFERED=1
export NATS_HOST=localhost
export NATS_PORT=4223
export NATS_STORE="$(mktemp -d)"
export API_SERVER_HOST=localhost
export API_SERVER_PORT=8005
# Start NATS Server
echo "Flushing NATS store: ${NATS_STORE}..."
rm -r "${NATS_STORE}"
echo "Starting NATS Server..."
nats-server -p ${NATS_PORT} --jetstream --store_dir "${NATS_STORE}" &
# Start API Server
echo "Starting LLM API Server..."
python3 -m llm.api_server \
--tokenizer neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--request-plane-uri ${NATS_HOST}:${NATS_PORT} \
--api-server-host ${API_SERVER_HOST} \
--model-name "baseline" \
--api-server-port ${API_SERVER_PORT} &
# Empty --log-dir will dump logs to stdout
echo "Starting vLLM baseline workers..."
CUDA_VISIBLE_DEVICES=0 \
VLLM_WORKER_ID=0 \
python3 -m llm.vllm.deploy \
--baseline-worker-count ${VLLM_BASELINE_WORKERS} \
--request-plane-uri ${NATS_HOST}:${NATS_PORT} \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 1000 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--baseline-tp-size ${VLLM_BASELINE_TP_SIZE} \
--log-dir ""
# NOTE: It may take more than a minute for the vllm worker to start up
# if the model weights aren't cached and need to be downloaded.
echo "Waiting for deployment to finish startup..."
sleep 60
# Make a Chat Completion Request
echo "Sending chat completions request..."
curl ${API_SERVER_HOST}:${API_SERVER_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "baseline",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
],
"temperature": 0,
"top_p": 0.95,
"max_tokens": 25,
"stream": true,
"n": 1,
"frequency_penalty": 0.0,
"stop": []
}'
#!/bin/bash
# FIXME: Convert this script to README steps
export VLLM_ATTENTION_BACKEND=FLASHINFER
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TORCH_HOST=localhost
export VLLM_TORCH_PORT=36183
export VLLM_BASELINE_WORKERS=0
export VLLM_CONTEXT_WORKERS=1
export VLLM_GENERATE_WORKERS=1
export VLLM_BASELINE_TP_SIZE=1
export VLLM_CONTEXT_TP_SIZE=1
export VLLM_GENERATE_TP_SIZE=1
export VLLM_LOGGING_LEVEL=INFO
export VLLM_DATA_PLANE_BACKEND=nccl
export PYTHONUNBUFFERED=1
export NATS_HOST=localhost
export NATS_PORT=4223
export NATS_STORE="$(mktemp -d)"
export API_SERVER_HOST=localhost
export API_SERVER_PORT=8005
# Start NATS Server
echo "Flushing NATS store: ${NATS_STORE}..."
rm -r "${NATS_STORE}"
echo "Starting NATS Server..."
nats-server -p ${NATS_PORT} --jetstream --store_dir "${NATS_STORE}" &
# Start API Server
echo "Starting LLM API Server..."
python3 -m llm.api_server \
--tokenizer neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--request-plane-uri ${NATS_HOST}:${NATS_PORT} \
--api-server-host ${API_SERVER_HOST} \
--model-name llama \
--api-server-port ${API_SERVER_PORT} &
# Start VLLM Worker 0
echo "Starting vLLM context workers..."
CUDA_VISIBLE_DEVICES=0 \
VLLM_WORKER_ID=0 \
python3 -m llm.vllm.deploy \
--context-worker-count ${VLLM_CONTEXT_WORKERS} \
--request-plane-uri ${NATS_HOST}:${NATS_PORT} \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--worker-name llama \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 1000 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--context-tp-size ${VLLM_CONTEXT_TP_SIZE} \
--generate-tp-size ${VLLM_GENERATE_TP_SIZE} \
--log-dir "/tmp/vllm_logs" &
# Start VLLM Worker 1
echo "Starting vLLM generate workers..."
CUDA_VISIBLE_DEVICES=1 \
VLLM_WORKER_ID=1 \
python3 -m llm.vllm.deploy \
--generate-worker-count ${VLLM_GENERATE_WORKERS} \
--request-plane-uri ${NATS_HOST}:${NATS_PORT} \
--model-name neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 \
--kv-cache-dtype fp8 \
--dtype auto \
--worker-name llama \
--disable-async-output-proc \
--disable-log-stats \
--max-model-len 1000 \
--max-batch-size 10000 \
--gpu-memory-utilization 0.9 \
--context-tp-size ${VLLM_CONTEXT_TP_SIZE} \
--generate-tp-size ${VLLM_GENERATE_TP_SIZE} \
--log-dir "/tmp/vllm_logs" &
# NOTE: It may take more than a minute for the vllm worker to start up
# if the model weights aren't cached and need to be downloaded.
echo "Waiting for deployment to finish startup..."
echo "Once you see all ranks connected to the server, it should be ready..."
echo "Example output:"
echo "\tRank 0 connected to the server"
echo "\t..."
echo "\tRank 1 connected to the server"
sleep 120
# Make a Chat Completion Request
echo "Sending chat completions request..."
curl ${API_SERVER_HOST}:${API_SERVER_PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama",
"messages": [
{"role": "system", "content": "What is the capital of France?"}
],
"temperature": 0,
"top_p": 0.95,
"max_tokens": 25,
"stream": true,
"n": 1,
"frequency_penalty": 0.0,
"stop": []
}'
import argparse
# FIXME: Remove unused args if any
def parse_args():
parser = argparse.ArgumentParser(description="Run an example of the VLLM pipeline.")
# example_dir = Path(__file__).parent.absolute().parent.absolute()
# default_log_dir = "" example_dir.joinpath("logs")
default_log_dir = ""
parser = argparse.ArgumentParser(description="Hello World Deployment")
parser.add_argument(
"--log-dir",
type=str,
default=str(default_log_dir),
help="log dir folder",
)
parser.add_argument(
"--request-plane-uri",
type=str,
default="nats://localhost:4223",
help="URI of request plane",
)
parser.add_argument(
"--initialize-request-plane",
default=False,
action="store_true",
help="Initialize the request plane, should only be done once per deployment",
)
parser.add_argument(
"--starting-metrics-port",
type=int,
default=0,
help="Metrics port for first worker. Each worker will expose metrics on subsequent ports, ex. worker 1: 50000, worker 2: 50001, worker 3: 50002",
)
parser.add_argument(
"--context-worker-count",
type=int,
required=False,
default=0,
help="Number of context workers",
)
parser.add_argument(
"--dummy-worker-count",
type=int,
required=False,
default=0,
help="Number of dummy workers",
)
parser.add_argument(
"--baseline-worker-count",
type=int,
required=False,
default=0,
help="Number of baseline workers",
)
parser.add_argument(
"--generate-worker-count",
type=int,
required=False,
default=0,
help="Number of generate workers",
)
parser.add_argument(
"--nats-url",
type=str,
required=False,
default="nats://localhost:4223",
help="URL of NATS server",
)
parser.add_argument(
"--model-name",
type=str,
required=False,
default="meta-llama/Meta-Llama-3.1-8B-Instruct",
help="Model name",
)
parser.add_argument(
"--worker-name",
type=str,
required=False,
default="llama",
help="Worker name",
)
parser.add_argument(
"--max-model-len",
type=int,
required=False,
default=None,
help="Maximum input/output latency length.",
)
parser.add_argument(
"--max-batch-size",
type=int,
required=False,
default=10000,
help="Max batch size",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
required=False,
default=0.45,
help="GPU memory utilization (fraction of memory from 0.0 to 1.0)",
)
parser.add_argument(
"--dtype",
type=str,
required=False,
default="float16",
help="Attention data type (float16, TODO: fp8)",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
required=False,
default="auto",
help="Key-value cache data type",
)
# FIXME: Support string values like 'debug', 'info, etc.
parser.add_argument(
"--log-level",
type=int,
required=False,
choices=[0, 1, 2],
default=1,
help="Logging level: 2=debug, 1=info, 0=error (default=1)",
)
## Logical arguments for vLLM engine
parser.add_argument(
"--enable-prefix-caching",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enable prefix caching",
)
parser.add_argument(
"--enable-chunked-prefill",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enable chunked prefill",
)
parser.add_argument(
"--enforce-eager",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Enforce eager execution",
)
parser.add_argument(
"--ignore-eos",
action=argparse.BooleanOptionalAction,
required=False,
default=False,
help="Ignore EOS token when generating",
)
parser.add_argument(
"--baseline-tp-size",
type=int,
default=1,
help="Tensor parallel size of a baseline worker.",
)
parser.add_argument(
"--context-tp-size",
type=int,
default=1,
help="Tensor parallel size of a context worker.",
)
parser.add_argument(
"--generate-tp-size",
type=int,
default=1,
help="Tensor parallel size of a generate worker.",
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=None,
help="maximum number of sequences per iteration",
)
parser.add_argument(
"--disable-async-output-proc",
action="store_true",
help="Disable async output processing",
)
parser.add_argument(
"--disable-log-stats",
action="store_true",
help="Disable logging statistics",
)
return parser.parse_args()
import argparse
from dataclasses import field
from typing import Any, Optional
from triton_distributed.icp.data_plane import DataPlane
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker import Operator, RemoteInferenceRequest
from .vllm_disaggregated.pipelines import (
GenerateStage,
PrefillStage,
SingleComputePipeline,
)
from .vllm_disaggregated.stage_executor import PiplineStageExecutor
class VllmContextOperator(Operator):
def __init__(
self,
name: str,
version: int,
triton_core,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
args = argparse.Namespace(**parameters) # type: ignore
stage = PrefillStage(
model=args.model_name,
tensor_parallel_size=args.context_tp_size,
generate_tensor_parallel_size=args.generate_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
kv_cache_dtype=args.kv_cache_dtype,
enable_prefix_caching=args.enable_prefix_caching,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
ignore_eos=args.ignore_eos,
max_num_seqs=args.max_num_seqs,
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(
args, request_plane, stage, "prefill", "generate"
)
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
class VllmGenerateOperator(Operator):
def __init__(
self,
name: str,
version: int,
triton_core,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
args = argparse.Namespace(**parameters) # type: ignore
args.worker_name = "generate"
stage = GenerateStage(
model=args.model_name,
tensor_parallel_size=args.generate_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
kv_cache_dtype=args.kv_cache_dtype,
enable_prefix_caching=args.enable_prefix_caching,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
ignore_eos=args.ignore_eos,
max_num_seqs=args.max_num_seqs,
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(args, request_plane, stage, "generate")
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
class VllmBaselineOperator(Operator):
def __init__(
self,
name: str,
version: int,
triton_core,
request_plane: RequestPlane,
data_plane: DataPlane,
parameters: Optional[dict[str, str | int | bool | bytes]] = field(
default_factory=dict
),
repository: Optional[str] = None,
logger: Optional[Any] = None,
):
args = argparse.Namespace(**parameters) # type: ignore
stage = SingleComputePipeline(
model=args.model_name,
tensor_parallel_size=args.baseline_tp_size,
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=args.max_model_len,
dtype=args.dtype,
kv_cache_dtype=args.kv_cache_dtype,
enable_prefix_caching=args.enable_prefix_caching,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
ignore_eos=args.ignore_eos,
max_num_seqs=args.max_num_seqs,
disable_async_output_proc=args.disable_async_output_proc,
disable_log_stats=args.disable_log_stats,
)
self.executor = PiplineStageExecutor(args, request_plane, stage, "baseline")
async def execute(self, requests: list[RemoteInferenceRequest]) -> None:
await self.executor.process_requests(requests)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import abc
import dataclasses
import typing
class TritonInferenceError(Exception):
"""Error occurred during Triton inference."""
@dataclasses.dataclass
class InferenceRequest:
"""Inference request."""
inputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class InferenceResponse:
"""Inference response."""
outputs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
error: typing.Optional[str] = None
final: bool = False
parameters: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
class BaseTriton3Connector(abc.ABC):
"""Base class for Triton 3 connector."""
@abc.abstractmethod
def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
raise NotImplementedError
async def list_models(self) -> typing.List[str]:
"""List models available in Triton 3 system.
Returns:
List of model names.
"""
raise NotImplementedError
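# Illustrative sketch (not part of the original file): a trivial connector that
# echoes its inputs back, showing the shape of the abstract interface above.
class EchoConnector(BaseTriton3Connector):
    """Toy connector that returns the request inputs unchanged."""

    async def inference(self, model_name, request):
        # Yield a single, final response whose outputs mirror the request inputs.
        yield InferenceResponse(outputs=dict(request.inputs), final=True)

    async def list_models(self):
        return ["echo"]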
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import time
from typing import Any, AsyncGenerator, Dict, Optional
import numpy as np
import vllm.engine.arg_utils
import vllm.engine.async_llm_engine
import vllm.inputs.data
LOGGER = vllm.logger.init_logger(__name__)
# FIXME currently streaming all the tokens is not efficient
# with RETURN_EVERY_N so large we return only first token and whole sequence at the end
RETURN_EVERY_N = 1000000
class SingleComputePipeline:
def __init__(
self,
**kwargs,
):
self._ignore_eos = kwargs.pop("ignore_eos", False)
engine_args = vllm.engine.arg_utils.AsyncEngineArgs(**kwargs)
LOGGER.info(f"Creating engine with args: {engine_args}")
self._engine = vllm.engine.async_llm_engine.AsyncLLMEngine.from_engine_args(
engine_args
)
LOGGER.info(f"Created engine: {self._engine}")
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
try:
vllm_input = input_payload["parameters"]["prompt"]
sampling_params = vllm.SamplingParams(
**input_payload["parameters"].get("sampling_params", {}),
ignore_eos=self._ignore_eos,
)
LOGGER.debug(f"sampling_params: {sampling_params}")
request_id = input_payload["parameters"].get("request_id", None)
results_generator = self._engine.generate(
vllm_input, sampling_params, request_id
)
LOGGER.debug("results_generator started")
counter = 0
async for result in results_generator:
if counter % RETURN_EVERY_N == 0 or result.finished:
tokens_ids = np.stack(
[output_row.token_ids for output_row in result.outputs]
).astype(np.int64)
LOGGER.debug(f"tokens_ids: {tokens_ids.shape}")
yield {
"outputs": {},
"error": None,
"final": result.finished,
"parameters": {
"text": result.outputs[0].text,
},
}
counter += 1
LOGGER.debug("results_generator finished")
except Exception as e:
LOGGER.error(f"Exception in SingleComputePipeline: {e}")
yield {"outputs": {}, "error": str(e), "final": True}
class PrefillStage:
def __init__(
self,
generate_tensor_parallel_size: Optional[int] = None,
**kwargs,
):
context_tensor_parallel_size = kwargs.get("tensor_parallel_size", 1)
generate_tensor_parallel_size = (
generate_tensor_parallel_size or context_tensor_parallel_size
)
assert (
generate_tensor_parallel_size % context_tensor_parallel_size == 0
), "generate_tensor_parallel_size must be multiple of context_tensor_parallel_size"
LOGGER.debug(f"context_tensor_parallel_size: {context_tensor_parallel_size}")
LOGGER.debug(f"generate_tensor_parallel_size: {generate_tensor_parallel_size}")
os.environ["VLLM_DISAGG_STAGE"] = "PREFILL"
os.environ["VLLM_CONTEXT_TP_SIZE"] = str(context_tensor_parallel_size)
os.environ["VLLM_GENERATE_TP_SIZE"] = str(generate_tensor_parallel_size)
LOGGER.info(f"Env VLLM_DISAGG_STAGE set to {os.environ['VLLM_DISAGG_STAGE']}")
kwargs[
"enforce_eager"
] = True # Prefill stage must be eager because of variable ISL
self._ignore_eos = kwargs.pop("ignore_eos", False)
engine_args = vllm.engine.arg_utils.AsyncEngineArgs(**kwargs)
LOGGER.info(f"Creating engine with args: {engine_args}")
self._engine = vllm.engine.async_llm_engine.AsyncLLMEngine.from_engine_args(
engine_args
)
LOGGER.info("Prefill stage initialized")
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
try:
vllm_input = input_payload["parameters"]["prompt"]
request_id = input_payload["parameters"].get("request_id", None)
assert request_id is not None, "request_id is required for prefill"
sampling_params = vllm.SamplingParams(
**input_payload["parameters"].get("sampling_params", {}),
ignore_eos=self._ignore_eos,
)
old_my_max_tokens = sampling_params.max_tokens
old_my_min_tokens = sampling_params.min_tokens
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
LOGGER.debug(f"sampling_params: {sampling_params}")
start_time_ns = time.monotonic_ns()
results_generator = self._engine.generate(
vllm_input, sampling_params, request_id
)
LOGGER.debug("results_generator started")
async for result in results_generator:
taken_ms = (time.monotonic_ns() - start_time_ns) / 1_000_000
LOGGER.info(
"==== Prefill completed kv cache taken %0.3fms ====", taken_ms
)
# TODO: needed to pass prompt, request_id, sampling_params to the next stage as there is no pipeline concept in online scenario
sampling_params.max_tokens = old_my_max_tokens
sampling_params.min_tokens = old_my_min_tokens
sampling_params_init_names = inspect.signature(
vllm.SamplingParams
).parameters.keys()
sampling_params = {
k: v
for k, v in sampling_params.__dict__.items()
if k in sampling_params_init_names
}
LOGGER.debug(
f"Yield response {input_payload['inputs'].keys()} parameters {input_payload['parameters']}"
)
yield {
"outputs": {}, # See line 195 for context
"error": None,
"parameters": {
**input_payload["parameters"],
"context_worker_id": os.environ["VLLM_WORKER_ID"],
"first_token": result.outputs[0].token_ids[0],
"seq_len": len(result.prompt_token_ids),
},
"final": True,
}
LOGGER.debug("Results generator for prefill finishes")
except Exception as e:
LOGGER.error(f"Exception in SingleComputePipeline: {e}")
yield {"outputs": {}, "error": str(e), "final": True}
class GenerateStage:
def __init__(
self,
**kwargs,
):
os.environ["VLLM_DISAGG_STAGE"] = "GENERATE"
LOGGER.info(f"Env VLLM_DISAGG_STAGE set to {os.environ['VLLM_DISAGG_STAGE']}")
self._ignore_eos = kwargs.pop("ignore_eos", False)
engine_args = vllm.engine.arg_utils.AsyncEngineArgs(**kwargs)
LOGGER.info(f"Creating engine with args: {engine_args}")
self._engine = vllm.engine.async_llm_engine.AsyncLLMEngine.from_engine_args(
engine_args
)
LOGGER.info("Generation stage initialized")
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
seq_len = input_payload["parameters"]["seq_len"]
LOGGER.debug(f"input sequence length: {seq_len}")
# we can use any tokens because first token is already sampled by the context worker
# and we just need the correct shape to allocate space in the kv cache
vllm_input = vllm.inputs.data.TokensPrompt(prompt_token_ids=[0] * seq_len)
sampling_params = vllm.SamplingParams(
**input_payload["parameters"].get("sampling_params", {}),
ignore_eos=self._ignore_eos,
)
LOGGER.debug(f"sampling_params: {sampling_params}")
request_id = input_payload["parameters"].get("request_id", None)
assert request_id is not None, "request_id is required for generate"
context_worker_id = input_payload["parameters"]["context_worker_id"]
new_request_id = f"{request_id}___{context_worker_id}"
first_token = input_payload["parameters"]["first_token"]
self._engine.engine.model_executor.driver_worker.model_runner.set_first_token(
new_request_id, first_token
)
# TODO ptarasiewicz this is only temporary way to pass worker id to the engine
# so that it can pull the correct kv cache
results_generator = self._engine.generate(
vllm_input,
sampling_params,
new_request_id,
)
LOGGER.debug("results_generator started")
counter = 0
async for result in results_generator:
if counter % RETURN_EVERY_N == 0 or result.finished:
yield {
"outputs": {},
"error": None,
"final": result.finished,
"parameters": {
"text": result.outputs[0].text,
},
}
counter += 1
LOGGER.debug("results_generator finished for generate")
class DisaggregatedPipeline:
def __init__(
self,
stage,
**kwargs,
):
if stage == "prefill":
LOGGER.info(f"initialize prefill {kwargs}")
self.stage = PrefillStage(**kwargs) # type: ignore
elif stage == "generate":
LOGGER.info(f"initialize generate {kwargs}")
self.stage = GenerateStage(**kwargs) # type: ignore
else:
raise ValueError(f"Unknown stage: {stage}")
async def __call__(
self, input_payload: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
LOGGER.debug("Start pipeline")
async for result in self.stage(input_payload):
LOGGER.debug("yield result")
yield result
LOGGER.debug("Pipeline generator finished")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Optional
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.icp.ucp_data_plane import UcpDataPlane
# UCP data plane causes deadlocks when used more than once, so we use a singleton
_g_singletonic_data_plane = None
_g_singletonic_data_plane_connection_count = 0
_g_actual_host = None
_g_actual_port = None
def set_actual_host_port(host, port):
global _g_actual_host
global _g_actual_port
if _g_singletonic_data_plane is not None:
raise Exception("Cannot set actual host and port after data plane is created")
_g_actual_host = host
_g_actual_port = port
def set_data_plane(data_plane):
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
_g_singletonic_data_plane_connection_count = 1
_g_singletonic_data_plane = data_plane
class RemoteConnector:
"""Handle connection to both request and data planes."""
def __init__(
self,
request_plane: RequestPlane,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RemoteConnector.
Args:
request_plane (RequestPlane): request plane used to exchange requests.
"""
global _g_singletonic_data_plane
global _g_actual_port
global _g_actual_host
self._request_plane = request_plane
if _g_singletonic_data_plane is None:
if _g_actual_host is not None:
data_plane_host = _g_actual_host
if _g_actual_port is not None:
data_plane_port = _g_actual_port
_g_singletonic_data_plane = UcpDataPlane(
hostname=data_plane_host,
port=data_plane_port,
keep_endpoints_open=keep_dataplane_endpoints_open,
)
self._connected = False
self._data_plane = _g_singletonic_data_plane
async def connect(self):
"""Connect to both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.connect()
_g_singletonic_data_plane_connection_count += 1
self._connected = True
async def close(self):
"""Disconnect from both request and data planes."""
global _g_singletonic_data_plane
global _g_singletonic_data_plane_connection_count
assert _g_singletonic_data_plane
await self._request_plane.close()
_g_singletonic_data_plane_connection_count -= 1
if _g_singletonic_data_plane_connection_count == 0:
_g_singletonic_data_plane.close()
_g_singletonic_data_plane = None
self._data_plane.close()
self._connected = False
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
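# Illustrative sketch (not part of the original file): typical lifecycle, assuming a
# RequestPlane instance named `request_plane` has been constructed elsewhere.
#
#     connector = RemoteConnector(request_plane, data_plane_host="localhost")
#     async with connector:
#         ...  # exchange requests/tensors via the request and data planes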
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
import json
import typing
from typing import Any, Coroutine, List, Optional
import numpy as np
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_operator import RemoteOperator
from triton_distributed.worker.remote_response import AsyncRemoteResponseIterator
from triton_distributed.worker.remote_tensor import RemoteTensor
from .connector import BaseTriton3Connector, InferenceRequest, InferenceResponse
from .remote_connector import RemoteConnector
class RemoteModelConnector(BaseTriton3Connector):
"""Connector for Triton 3 model."""
def __init__(
self,
request_plane: RequestPlane,
model_name: str,
model_version: Optional[str] = None,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize Triton 3 connector.
Args:
request_plane: Request plane instance used to route requests.
model_name: Model name.
model_version: Model version. Default is "1".
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example:
remote_model_connector = RemoteModelConnector(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
model_name="model_name",
model_version="1",
)
async with remote_model_connector:
request = InferenceRequest(inputs={"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
async for response in remote_model_connector.inference(model_name="model_name", request=request):
print(response.outputs)
"""
self._model = None
self._connector = RemoteConnector(
request_plane,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._model_name = model_name
if model_version is None:
model_version = "1"
self._model_version = model_version
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
self._model = RemoteOperator(
operator=self._model_name,
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
self._model = None
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def inference(
self, model_name: str, request: InferenceRequest
) -> typing.AsyncGenerator[InferenceResponse, None]:
"""Inference request to Triton 3 system.
Args:
model_name: Model name.
request: Inference request.
Returns:
Inference response.
Raises:
TritonInferenceError: error occurred during inference.
"""
if not self._connector._connected or self._model is None:
await self.connect()
else:
if self._model_name != model_name:
self._model_name = model_name
self._model_version = "1"
self._model = RemoteOperator(
operator=self._model_name,
request_plane=self._connector._request_plane,
data_plane=self._connector._data_plane,
)
results: List[Coroutine[Any, Any, AsyncRemoteResponseIterator]] = []
for key, value in request.parameters.items():
if isinstance(value, dict):
request.parameters[key] = "JSON:" + json.dumps(value)
assert self._model is not None
results.append(
self._model.async_infer(
inputs=request.inputs,
parameters=request.parameters,
)
)
for result in asyncio.as_completed(results):
responses = await result
async for response in responses:
triton_response = response.to_model_infer_response(
self._connector._data_plane
)
outputs = {}
for output in triton_response.outputs:
remote_tensor = RemoteTensor(output, self._connector._data_plane)
try:
local_tensor = remote_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
finally:
# FIXME: This is a workaround for the issue that the remote tensor
# is released after connection is closed.
remote_tensor.__del__()
outputs[output.name] = numpy_tensor
infer_response = InferenceResponse(
outputs=outputs,
final=response.final,
parameters=response.parameters,
)
yield infer_response
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional, Tuple
import numpy as np
from pydantic import BaseModel
from tritonserver import Tensor as TritonTensor
from tritonserver._api._response import InferenceResponse as TritonInferenceResponse
from triton_distributed.icp.request_plane import RequestPlane
from triton_distributed.worker.remote_request import RemoteInferenceRequest
from triton_distributed.worker.remote_response import RemoteInferenceResponse
from .remote_connector import RemoteConnector
class LocalModel(BaseModel):
name: str
version: str
class RequestConverter:
"""Request converter. Class converts requests to convenient format for processing."""
def __init__(
self,
request_plane: RequestPlane,
model_name: str,
data_plane_host: Optional[str] = None,
data_plane_port: int = 0,
keep_dataplane_endpoints_open: bool = False,
):
"""Initialize RequestAdapter.
Args:
request_plane: Request plane instance used to pull requests.
data_plane_host: Data plane host (e.g. "localhost").
data_plane_port: Data plane port (e.g. 8001). You can use 0 to let the system choose a port.
keep_dataplane_endpoints_open: Keep data plane endpoints open to avoid reconnecting. Default is False.
Example for async model:
worker = RequestConverter(
nats_url="localhost:4222",
data_plane_host="localhost",
data_plane_port=8001,
)
async with worker:
# This flow will process 10 requests at a time
processors = []
async def processing(request, callable):
inputs = request["inputs"]
parameters = request["parameters"]
output_tensor = inputs["a"] + inputs["b"]
try:
await callable({"c": output_tensor})
for _ in range(parameters["increment"]):
output_tensor += 1
await callable({"c": output_tensor})
finally:
await callable({"c": output_tensor}, final=True)
async for request, callable in worker.pull(model_name="model_name", batch_size=10):
# Check if batch size was reached
if len(processors) >= 10:
done, pending = asyncio.wait(processors, return_when=asyncio.FIRST_COMPLETED)
processors = list(pending)
processors.append(processing(request, callable))
"""
self._connector = RemoteConnector(
request_plane,
data_plane_host,
data_plane_port,
keep_dataplane_endpoints_open=keep_dataplane_endpoints_open,
)
self._local_model = LocalModel(name=model_name, version="1")
async def connect(self):
"""Connect to Triton 3 server."""
await self._connector.connect()
async def close(self):
"""Disconnect from Triton 3 server."""
await self._connector.close()
async def __aenter__(self):
"""Enter context manager."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
"""Exit context manager."""
await self.close()
async def pull(
self,
model_name: str,
model_version: Optional[str] = None,
batch_size: Optional[int] = None,
timeout: Optional[float] = None,
) -> AsyncGenerator[
Tuple[Dict[str, Any], Callable[[Dict[str, Any]], Awaitable[None]]], None
]:
"""Pull requests from request plane and data plane.
Pull returns an async generator that yields a tuple of request and callable.
Request contains inputs and parameters. Inputs are a dictionary of input names and numpy arrays. Parameters are
a dictionary of scalar parameters like sampling parameters in language models.
Callable is a function that takes outputs, error and final as arguments. Outputs are a dictionary of output names
and numpy arrays. Error is an optional exception or error message. Final is a boolean that indicates whether the response is final.
Args:
model_name: Model name.
model_version: Model version. Default is "1".
batch_size: Maximum number of requests pulled at once. Default is 1.
timeout: Max duration of the pull request before it expires. Default is None.
Yields:
Tuple of request (dict with "inputs" and "parameters") and a response callable.
Example:
worker = RequestConverter(
request_plane=request_plane,
model_name="model_name",
data_plane_host="localhost",
data_plane_port=8001,
)
async with worker:
# This flow will process a single request at a time
async for request, callable in worker.pull(model_name="model_name"):
# A simple add model that increments the output tensor by the "increment" parameter
inputs = request["inputs"]
parameters = request["parameters"]
output_tensor = inputs["a"] + inputs["b"]
try:
await callable({"c": output_tensor})
for _ in range(parameters["increment"]):
output_tensor += 1
await callable({"c": output_tensor})
finally:
await callable({"c": output_tensor}, final=True)
"""
if not self._connector._connected:
await self.connect()
if model_version is None:
model_version = "1"
if batch_size is None:
batch_size = 1
local_model = LocalModel(
name=model_name,
version=model_version,
)
kwargs = {
"model_name": model_name,
"model_version": model_version,
"number_requests": batch_size,
}
if timeout is not None:
kwargs["timeout"] = timeout
while True:
requests_iterator = await self._connector._request_plane.pull_requests(
**kwargs
)
async for request in requests_iterator:
inputs, remote_request, return_callable = await self.adapt_request(
request, local_model
)
yield {
"inputs": inputs,
"parameters": remote_request.parameters,
}, return_callable
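# Convert a raw request from the request plane into local numpy inputs,
# the originating RemoteInferenceRequest, and a callable for posting responses.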
async def adapt_request(self, request, local_model: Optional[LocalModel] = None):
if local_model is None:
local_model = self._local_model
if isinstance(request, RemoteInferenceRequest):
remote_request = request
request = remote_request.to_model_infer_request()
else:
remote_request = RemoteInferenceRequest.from_model_infer_request(
request,
self._connector._data_plane,
self._connector._request_plane,
)
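# Build a per-request callable that wraps local outputs as remote tensors
# and posts the response back on the request plane.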
def produce_callable(request):
async def return_callable(
outputs: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = None,
error: Optional[str] = None,
final: Optional[bool] = False,
) -> None:
request_id = request.parameters["icp_request_id"].string_param
infer_kwargs = {
"model": local_model,
"request_id": request_id,
}
if error is not None:
infer_kwargs["error"] = error
else:
outputs_tensors = {}
for name, value in outputs.items():
outputs_tensors[name] = TritonTensor.from_dlpack(value)
infer_kwargs["outputs"] = outputs_tensors
if final is not None:
infer_kwargs["final"] = final
if parameters is not None:
infer_kwargs["parameters"] = parameters
local_response = TritonInferenceResponse(**infer_kwargs)
remote_response = RemoteInferenceResponse.from_local_response(
local_response,
).to_model_infer_response(self._connector._data_plane)
# FIXME: This is a workaround for the scenario where the connector isn't
# connected yet when posting a response to the request plane.
if not self._connector._connected:
await self.connect()
await self._connector._request_plane.post_response(
request,
remote_response,
)
return return_callable
return_callable = produce_callable(request)
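# Materialize remote input tensors as local numpy arrays and release the
# remote handles explicitly once the local tensors are available.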
inputs = {}
for name, input_tensor in remote_request.inputs.items():
local_tensor = input_tensor.local_tensor
numpy_tensor = np.from_dlpack(local_tensor)
input_tensor.__del__()
inputs[name] = numpy_tensor
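# Parameters serialized with a "JSON:" prefix are decoded back into Python objects.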
for key, value in remote_request.parameters.items():
if isinstance(value, str) and value.startswith("JSON:"):
remote_request.parameters[key] = json.loads(value[5:])
return inputs, remote_request, return_callable
import asyncio
import enum
import logging
import os
from contextlib import nullcontext
import torch
from .connector import InferenceRequest
from .remote_model_connector import RemoteModelConnector
from .request_converter import RequestConverter
LOGGER = logging.getLogger(__name__)
class _ProfileState(enum.Enum):
NOT_STARTED = 0
STARTED = 1
STOPPED = 2
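# Executes a single pipeline stage (e.g. context or generate) of the
# disaggregated serving example. For a context stage the output is forwarded
# as a new request to the next stage via the remote model connector.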
class PiplineStageExecutor:
def __init__(self, args, request_plane, stage, stage_name, next_stage_name=None):
self.args = args
self.stage = stage
self.stage_name = stage_name
self.is_context_stage = next_stage_name is not None
self.next_stage_name = next_stage_name
self.remote_model_connector = (
RemoteModelConnector(
request_plane=request_plane,
model_name=self.next_stage_name,
keep_dataplane_endpoints_open=True,
)
if self.is_context_stage
else None
)
self.request_converter = RequestConverter(
request_plane=request_plane,
keep_dataplane_endpoints_open=True,
model_name=self.args.worker_name,
)
self.request_counter = 0
self.profile_state = _ProfileState.NOT_STARTED
self.tasks = []
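# Baseline path: run the stage locally and stream its responses straight
# back to the caller without forwarding to another stage.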
async def baseline_process(self, request, return_result):
try:
LOGGER.debug("Processing request")
async for response in self.stage(request):
LOGGER.debug("Sending response")
await return_result(**response)
LOGGER.debug("Response send")
except Exception as e:
LOGGER.error(f"Error processing request: {e}")
await return_result({"error": e, "final": True})
LOGGER.debug("Processing finished")
async def process(self, request, return_result):
LOGGER.debug("Processing request")
try:
LOGGER.debug(f"Stage {self.stage_name} execution")
responses = [response async for response in self.stage(request)]
LOGGER.debug(f"Stage {self.stage_name} finished")
assert len(responses) == 1
response = responses[0]
parameters = response.get("parameters", {})
if not parameters:
raise RuntimeError(
f"ERROR: Response parameters from stage {self.stage_name} should not be empty!"
)
outputs = response.get("outputs", {})
request = InferenceRequest(inputs=outputs, parameters=parameters)
LOGGER.info(f"Next stage {self.next_stage_name} execution")
assert self.remote_model_connector is not None
async for response in self.remote_model_connector.inference(
model_name=self.next_stage_name, request=request
):
LOGGER.debug(f"Stage {self.stage_name} sending response")
await return_result(
outputs=response.outputs,
final=response.final,
parameters={"text": response.parameters["text"]},
)
LOGGER.debug(f"Stage {self.stage_name} sended response")
except Exception as e:
LOGGER.error(f"Error processing request: {e}", exc_info=True)
await return_result(outputs={}, error=e, final=True)
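# Main loop: pull requests addressed to this worker from the request plane
# and dispatch each one for processing.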
async def handle_pipelined_requests(self):
LOGGER.info(
f"Start handling requests stage_name {self.stage_name} args {self.args}"
)
async with self.request_converter, self.remote_model_connector or nullcontext():
LOGGER.info(f"Stage {self.stage_name} starts pulling")
async for request, return_result in self.request_converter.pull(
model_name=self.args.worker_name
):
# TODO ptarasiewicz - only one context or generate should be profiled at a time
await self.process_request(request, return_result)
LOGGER.info(f"Stage {self.stage_name} finished pulling")
async def process_requests(self, requests):
for raw_request in requests:
(
inputs,
remote_request,
return_callable,
) = await self.request_converter.adapt_request(raw_request)
request, return_result = {
"inputs": inputs,
"parameters": remote_request.parameters,
}, return_callable
await self.process_request(request, return_result)
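# Dispatch a single request as an asyncio task, capping the number of
# in-flight tasks at max_batch_size by waiting for at least one to finish.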
async def process_request(self, request, return_result):
self._profile()
if self.is_context_stage:
process_function = self.process
else:
process_function = self.baseline_process
self.request_counter += 1
LOGGER.debug(f"Stage {self.stage_name} pulled request")
self.tasks.append(asyncio.create_task(process_function(request, return_result)))
if len(self.tasks) >= self.args.max_batch_size:
LOGGER.debug(
f"Stage {self.stage_name} waiting some of {len(self.tasks)} requests to finish"
)
_, pending = await asyncio.wait(
self.tasks, return_when=asyncio.FIRST_COMPLETED
)
self.tasks = list(pending)
LOGGER.debug(
f"Stage {self.stage_name} finished some requests with {len(self.tasks)} to do"
)
def _profile(self):
if os.environ.get("RUN_PROFILING") == "1":
if (
self.profile_state == _ProfileState.NOT_STARTED
and self.request_counter > 100
):
LOGGER.info("Start profiling")
torch.cuda.profiler.start()
self.profile_state = _ProfileState.STARTED
elif (
self.profile_state == _ProfileState.STARTED
and self.request_counter > 120
):
LOGGER.info("Stop profiling")
torch.cuda.profiler.stop()
self.profile_state = _ProfileState.STOPPED
# Alternatively, the profiled region can be wrapped in `with torch.cuda.profiler.profile():`
...@@ -112,20 +112,27 @@ class NatsServer: ...@@ -112,20 +112,27 @@ class NatsServer:
print(command) print(command)
return return
os.makedirs(log_dir, exist_ok=True)
if clear_store: if clear_store:
shutil.rmtree(store_dir, ignore_errors=True) shutil.rmtree(store_dir, ignore_errors=True)
with open(f"{log_dir}/nats_server.stdout.log", "wt") as output_: if log_dir:
with open(f"{log_dir}/nats_server.stderr.log", "wt") as output_err: os.makedirs(log_dir, exist_ok=True)
process = subprocess.Popen(
command, with open(f"{log_dir}/nats_server.stdout.log", "wt") as output_:
stdin=subprocess.DEVNULL, with open(f"{log_dir}/nats_server.stderr.log", "wt") as output_err:
stdout=output_, process = subprocess.Popen(
stderr=output_err, command,
) stdin=subprocess.DEVNULL,
self._process = process stdout=output_,
stderr=output_err,
)
self._process = process
else:
process = subprocess.Popen(
command,
stdin=subprocess.DEVNULL,
)
self._process = process
def __del__(self): def __del__(self):
if self._process: if self._process:
...@@ -196,7 +203,9 @@ class NatsRequestPlane(RequestPlane): ...@@ -196,7 +203,9 @@ class NatsRequestPlane(RequestPlane):
Optional[nats.js.JetStreamContext.PullSubscription], Optional[nats.js.JetStreamContext.PullSubscription],
]: ]:
if self._jet_stream is None: if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!") raise InvalidArgumentError(
"Failed to get model stream: NATS Jetstream not connected!"
)
if (model_name, model_version) in self._model_streams: if (model_name, model_version) in self._model_streams:
return self._model_streams[(model_name, model_version)] return self._model_streams[(model_name, model_version)]
...@@ -326,7 +335,9 @@ class NatsRequestPlane(RequestPlane): ...@@ -326,7 +335,9 @@ class NatsRequestPlane(RequestPlane):
responses: AsyncIterator[ModelInferResponse] | ModelInferResponse, responses: AsyncIterator[ModelInferResponse] | ModelInferResponse,
): ):
if self._jet_stream is None: if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!") raise InvalidArgumentError(
"Failed to post response: NATS Jetstream not connected!"
)
request_id = get_icp_request_id(request) request_id = get_icp_request_id(request)
if request_id is None: if request_id is None:
...@@ -367,7 +378,9 @@ class NatsRequestPlane(RequestPlane): ...@@ -367,7 +378,9 @@ class NatsRequestPlane(RequestPlane):
] = None, ] = None,
) -> AsyncIterator[ModelInferResponse]: ) -> AsyncIterator[ModelInferResponse]:
if self._jet_stream is None: if self._jet_stream is None:
raise InvalidArgumentError("Not Connected!") raise InvalidArgumentError(
"Failed to post request: NATS Jetstream not connected!"
)
if response_iterator and response_handler: if response_iterator and response_handler:
raise InvalidArgumentError( raise InvalidArgumentError(
......