Commit 52f97799 authored by laibao's avatar laibao
Browse files

更新README.md,修改Docker镜像版本,更新深度学习库版本,调整推理示例命令,删除不再使用的示例文件。

parent 35cc3501
# Helm Charts
This directory contains a Helm chart for deploying the vllm application. The chart includes configurations for deployment, autoscaling, resource management, and more.
## Files
- Chart.yaml: Defines the chart metadata including name, version, and maintainers.
- ct.yaml: Configuration for chart testing.
- lintconf.yaml: Linting rules for YAML files.
- values.schema.json: JSON schema for validating values.yaml.
- values.yaml: Default values for the Helm chart.
- templates/_helpers.tpl: Helper templates for defining common configurations.
- templates/configmap.yaml: Template for creating ConfigMaps.
- templates/custom-objects.yaml: Template for custom Kubernetes objects.
- templates/deployment.yaml: Template for creating Deployments.
- templates/hpa.yaml: Template for Horizontal Pod Autoscaler.
- templates/job.yaml: Template for Kubernetes Jobs.
- templates/poddisruptionbudget.yaml: Template for Pod Disruption Budget.
- templates/pvc.yaml: Template for Persistent Volume Claims.
- templates/secrets.yaml: Template for Kubernetes Secrets.
- templates/service.yaml: Template for creating Services.
chart-dirs:
- charts
validate-maintainers: false
\ No newline at end of file
---
rules:
braces:
min-spaces-inside: 0
max-spaces-inside: 0
min-spaces-inside-empty: -1
max-spaces-inside-empty: -1
brackets:
min-spaces-inside: 0
max-spaces-inside: 0
min-spaces-inside-empty: -1
max-spaces-inside-empty: -1
colons:
max-spaces-before: 0
max-spaces-after: 1
commas:
max-spaces-before: 0
min-spaces-after: 1
max-spaces-after: 1
comments:
require-starting-space: true
min-spaces-from-content: 2
document-end: disable
document-start: disable # No --- to start a file
empty-lines:
max: 2
max-start: 0
max-end: 0
hyphens:
max-spaces-after: 1
indentation:
spaces: consistent
indent-sequences: whatever # - list indentation will handle both indentation and without
check-multi-line-strings: false
key-duplicates: enable
line-length: disable # Lines can be any length
new-line-at-end-of-file: disable
new-lines:
type: unix
trailing-spaces: enable
truthy:
level: warning
\ No newline at end of file
{{/*
Define ports for the pods
*/}}
{{- define "chart.container-port" -}}
{{- default "8000" .Values.containerPort }}
{{- end }}
{{/*
Define service name
*/}}
{{- define "chart.service-name" -}}
{{- if .Values.serviceName }}
{{- .Values.serviceName | lower | trim }}
{{- else }}
"{{ .Release.Name }}-service"
{{- end }}
{{- end }}
{{/*
Define service port
*/}}
{{- define "chart.service-port" -}}
{{- if .Values.servicePort }}
{{- .Values.servicePort }}
{{- else }}
{{- include "chart.container-port" . }}
{{- end }}
{{- end }}
{{/*
Define service port name
*/}}
{{- define "chart.service-port-name" -}}
"service-port"
{{- end }}
{{/*
Define container port name
*/}}
{{- define "chart.container-port-name" -}}
"container-port"
{{- end }}
{{/*
Define deployment strategy
*/}}
{{- define "chart.strategy" -}}
strategy:
{{- if not .Values.deploymentStrategy }}
rollingUpdate:
maxSurge: 100%
maxUnavailable: 0
{{- else }}
{{ toYaml .Values.deploymentStrategy | indent 2 }}
{{- end }}
{{- end }}
{{/*
Define additional ports
*/}}
{{- define "chart.extraPorts" }}
{{- with .Values.extraPorts }}
{{ toYaml . }}
{{- end }}
{{- end }}
{{/*
Define chart external ConfigMaps and Secrets
*/}}
{{- define "chart.externalConfigs" -}}
{{- with .Values.externalConfigs -}}
{{ toYaml . }}
{{- end }}
{{- end }}
{{/*
Define liveness et readiness probes
*/}}
{{- define "chart.probes" -}}
{{- if .Values.readinessProbe }}
readinessProbe:
{{- with .Values.readinessProbe }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- if .Values.livenessProbe }}
livenessProbe:
{{- with .Values.livenessProbe }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Define resources
*/}}
{{- define "chart.resources" -}}
requests:
memory: {{ required "Value 'resources.requests.memory' must be defined !" .Values.resources.requests.memory | quote }}
cpu: {{ required "Value 'resources.requests.cpu' must be defined !" .Values.resources.requests.cpu | quote }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
nvidia.com/gpu: {{ required "Value 'resources.requests.nvidia.com/gpu' must be defined !" (index .Values.resources.requests "nvidia.com/gpu") | quote }}
{{- end }}
limits:
memory: {{ required "Value 'resources.limits.memory' must be defined !" .Values.resources.limits.memory | quote }}
cpu: {{ required "Value 'resources.limits.cpu' must be defined !" .Values.resources.limits.cpu | quote }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
nvidia.com/gpu: {{ required "Value 'resources.limits.nvidia.com/gpu' must be defined !" (index .Values.resources.limits "nvidia.com/gpu") | quote }}
{{- end }}
{{- end }}
{{/*
Define User used for the main container
*/}}
{{- define "chart.user" }}
{{- if .Values.image.runAsUser }}
runAsUser:
{{- with .Values.runAsUser }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end }}
{{- end }}
{{- define "chart.extraInitImage" -}}
"amazon/aws-cli:2.6.4"
{{- end }}
{{- define "chart.extraInitEnv" -}}
- name: S3_ENDPOINT_URL
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3endpoint
- name: S3_BUCKET_NAME
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3bucketname
- name: AWS_ACCESS_KEY_ID
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3accesskeyid
- name: AWS_SECRET_ACCESS_KEY
valueFrom:
secretKeyRef:
name: {{ .Release.Name }}-secrets
key: s3accesskey
- name: S3_PATH
value: "{{ .Values.extraInit.s3modelpath }}"
- name: AWS_EC2_METADATA_DISABLED
value: "{{ .Values.extraInit.awsEc2MetadataDisabled }}"
{{- end }}
{{/*
Define chart labels
*/}}
{{- define "chart.labels" -}}
{{- with .Values.labels -}}
{{ toYaml . }}
{{- end }}
{{- end }}
\ No newline at end of file
{{- if .Values.configs -}}
apiVersion: v1
kind: ConfigMap
metadata:
name: "{{ .Release.Name }}-configs"
namespace: {{ .Release.Namespace }}
data:
{{- with .Values.configs }}
{{- toYaml . | nindent 2 }}
{{- end }}
{{- end -}}
\ No newline at end of file
{{- if .Values.customObjects }}
{{- range .Values.customObjects }}
{{- tpl (. | toYaml) $ }}
---
{{- end }}
{{- end }}
\ No newline at end of file
apiVersion: apps/v1
kind: Deployment
metadata:
name: "{{ .Release.Name }}-deployment-vllm"
namespace: {{ .Release.Namespace }}
labels:
{{- include "chart.labels" . | nindent 4 }}
spec:
replicas: {{ .Values.replicaCount }}
{{- include "chart.strategy" . | nindent 2 }}
selector:
matchLabels:
environment: "test"
release: "test"
progressDeadlineSeconds: 1200
template:
metadata:
labels:
environment: "test"
release: "test"
spec:
containers:
- name: "vllm"
image: "{{ required "Required value 'image.repository' must be defined !" .Values.image.repository }}:{{ required "Required value 'image.tag' must be defined !" .Values.image.tag }}"
{{- if .Values.image.command }}
command :
{{- with .Values.image.command }}
{{- toYaml . | nindent 10 }}
{{- end }}
{{- end }}
securityContext:
{{- if .Values.image.securityContext }}
{{- with .Values.image.securityContext }}
{{- toYaml . | nindent 12 }}
{{- end }}
{{- else }}
runAsNonRoot: false
{{- include "chart.user" . | indent 12 }}
{{- end }}
imagePullPolicy: IfNotPresent
{{- if .Values.image.env }}
env :
{{- with .Values.image.env }}
{{- toYaml . | nindent 10 }}
{{- end }}
{{- else }}
env: []
{{- end }}
{{- if or .Values.externalConfigs .Values.configs .Values.secrets }}
envFrom:
{{- if .Values.configs }}
- configMapRef:
name: "{{ .Release.Name }}-configs"
{{- end }}
{{- if .Values.secrets}}
- secretRef:
name: "{{ .Release.Name }}-secrets"
{{- end }}
{{- include "chart.externalConfigs" . | nindent 12 }}
{{- end }}
ports:
- name: {{ include "chart.container-port-name" . }}
containerPort: {{ include "chart.container-port" . }}
{{- include "chart.extraPorts" . | nindent 12 }}
{{- include "chart.probes" . | indent 10 }}
resources: {{- include "chart.resources" . | nindent 12 }}
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
{{- with .Values.extraContainers }}
{{ toYaml . | nindent 8 }}
{{- end }}
{{- if .Values.extraInit }}
initContainers:
- name: wait-download-model
image: {{ include "chart.extraInitImage" . }}
command:
- /bin/bash
args:
- -eucx
- while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done
env: {{- include "chart.extraInitEnv" . | nindent 10 }}
resources:
requests:
cpu: 200m
memory: 1Gi
limits:
cpu: 500m
memory: 2Gi
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
{{- end }}
volumes:
- name: {{ .Release.Name }}-storage
persistentVolumeClaim:
claimName: {{ .Release.Name }}-storage-claim
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- if and (gt (int (index .Values.resources.requests "nvidia.com/gpu")) 0) (gt (int (index .Values.resources.limits "nvidia.com/gpu")) 0) }}
runtimeClassName: nvidia
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: nvidia.com/gpu.product
operator: In
{{- with .Values.gpuModels }}
values:
{{- toYaml . | nindent 20 }}
{{- end }}
{{- end }}
\ No newline at end of file
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: "{{ .Release.Name }}-hpa"
namespace: {{ .Release.Namespace }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: vllm
minReplicas: {{ .Values.autoscaling.minReplicas }}
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
metrics:
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
{{- end }}
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
{{- end }}
{{- end }}
\ No newline at end of file
{{- if .Values.extraInit }}
apiVersion: batch/v1
kind: Job
metadata:
name: "{{ .Release.Name }}-init-vllm"
namespace: {{ .Release.Namespace }}
spec:
ttlSecondsAfterFinished: 100
template:
metadata:
name: init-vllm
spec:
containers:
- name: job-download-model
image: {{ include "chart.extraInitImage" . }}
command:
- /bin/bash
args:
- -eucx
- aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data
env: {{- include "chart.extraInitEnv" . | nindent 8 }}
volumeMounts:
- name: {{ .Release.Name }}-storage
mountPath: /data
resources:
requests:
cpu: 200m
memory: 1Gi
limits:
cpu: 500m
memory: 2Gi
restartPolicy: OnFailure
volumes:
- name: {{ .Release.Name }}-storage
persistentVolumeClaim:
claimName: "{{ .Release.Name }}-storage-claim"
{{- end }}
\ No newline at end of file
apiVersion: policy/v1
kind: PodDisruptionBudget
metadata:
name: "{{ .Release.Name }}-pdb"
namespace: {{ .Release.Namespace }}
spec:
maxUnavailable: {{ default 1 .Values.maxUnavailablePodDisruptionBudget }}
\ No newline at end of file
{{- if .Values.extraInit }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: "{{ .Release.Name }}-storage-claim"
namespace: {{ .Release.Namespace }}
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: {{ .Values.extraInit.pvcStorage }}
{{- end }}
\ No newline at end of file
apiVersion: v1
kind: Secret
metadata:
name: "{{ .Release.Name }}-secrets"
namespace: {{ .Release.Namespace }}
type: Opaque
data:
{{- range $key, $val := .Values.secrets }}
{{ $key }}: {{ $val | b64enc | quote }}
{{- end }}
\ No newline at end of file
apiVersion: v1
kind: Service
metadata:
name: "{{ .Release.Name }}-service"
namespace: {{ .Release.Namespace }}
spec:
type: ClusterIP
ports:
- name: {{ include "chart.service-port-name" . }}
port: {{ include "chart.service-port" . }}
targetPort: {{ include "chart.container-port-name" . }}
protocol: TCP
selector:
{{- include "chart.labels" . | nindent 4 }}
\ No newline at end of file
{
"$schema": "http://json-schema.org/schema#",
"type": "object",
"properties": {
"image": {
"type": "object",
"properties": {
"repository": {
"type": "string"
},
"tag": {
"type": "string"
},
"command": {
"type": "array",
"items": {
"type": "string"
}
}
},
"required": [
"command",
"repository",
"tag"
]
},
"containerPort": {
"type": "integer"
},
"serviceName": {
"type": "null"
},
"servicePort": {
"type": "integer"
},
"extraPorts": {
"type": "array"
},
"replicaCount": {
"type": "integer"
},
"deploymentStrategy": {
"type": "object"
},
"resources": {
"type": "object",
"properties": {
"requests": {
"type": "object",
"properties": {
"cpu": {
"type": "integer"
},
"memory": {
"type": "string"
},
"nvidia.com/gpu": {
"type": "integer"
}
},
"required": [
"cpu",
"memory",
"nvidia.com/gpu"
]
},
"limits": {
"type": "object",
"properties": {
"cpu": {
"type": "integer"
},
"memory": {
"type": "string"
},
"nvidia.com/gpu": {
"type": "integer"
}
},
"required": [
"cpu",
"memory",
"nvidia.com/gpu"
]
}
},
"required": [
"limits",
"requests"
]
},
"gpuModels": {
"type": "array",
"items": {
"type": "string"
}
},
"autoscaling": {
"type": "object",
"properties": {
"enabled": {
"type": "boolean"
},
"minReplicas": {
"type": "integer"
},
"maxReplicas": {
"type": "integer"
},
"targetCPUUtilizationPercentage": {
"type": "integer"
}
},
"required": [
"enabled",
"maxReplicas",
"minReplicas",
"targetCPUUtilizationPercentage"
]
},
"configs": {
"type": "object"
},
"secrets": {
"type": "object"
},
"externalConfigs": {
"type": "array"
},
"customObjects": {
"type": "array"
},
"maxUnavailablePodDisruptionBudget": {
"type": "string"
},
"extraInit": {
"type": "object",
"properties": {
"s3modelpath": {
"type": "string"
},
"pvcStorage": {
"type": "string"
},
"awsEc2MetadataDisabled": {
"type": "boolean"
}
},
"required": [
"pvcStorage",
"s3modelpath",
"awsEc2MetadataDisabled"
]
},
"extraContainers": {
"type": "array"
},
"readinessProbe": {
"type": "object",
"properties": {
"initialDelaySeconds": {
"type": "integer"
},
"periodSeconds": {
"type": "integer"
},
"failureThreshold": {
"type": "integer"
},
"httpGet": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"port": {
"type": "integer"
}
},
"required": [
"path",
"port"
]
}
},
"required": [
"failureThreshold",
"httpGet",
"initialDelaySeconds",
"periodSeconds"
]
},
"livenessProbe": {
"type": "object",
"properties": {
"initialDelaySeconds": {
"type": "integer"
},
"failureThreshold": {
"type": "integer"
},
"periodSeconds": {
"type": "integer"
},
"httpGet": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"port": {
"type": "integer"
}
},
"required": [
"path",
"port"
]
}
},
"required": [
"failureThreshold",
"httpGet",
"initialDelaySeconds",
"periodSeconds"
]
},
"labels": {
"type": "object",
"properties": {
"environment": {
"type": "string"
},
"release": {
"type": "string"
}
},
"required": [
"environment",
"release"
]
}
},
"required": [
"autoscaling",
"configs",
"containerPort",
"customObjects",
"deploymentStrategy",
"externalConfigs",
"extraContainers",
"extraInit",
"extraPorts",
"gpuModels",
"image",
"labels",
"livenessProbe",
"maxUnavailablePodDisruptionBudget",
"readinessProbe",
"replicaCount",
"resources",
"secrets",
"servicePort"
]
}
\ No newline at end of file
# -- Default values for chart vllm
# -- Declare variables to be passed into your templates.
# -- Image configuration
image:
# -- Image repository
repository: "vllm/vllm-openai"
# -- Image tag
tag: "latest"
# -- Container launch command
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
# -- Container port
containerPort: 8000
# -- Service name
serviceName:
# -- Service port
servicePort: 80
# -- Additional ports configuration
extraPorts: []
# -- Number of replicas
replicaCount: 1
# -- Deployment strategy configuration
deploymentStrategy: {}
# -- Resource configuration
resources:
requests:
# -- Number of CPUs
cpu: 4
# -- CPU memory configuration
memory: 16Gi
# -- Number of gpus used
nvidia.com/gpu: 1
limits:
# -- Number of CPUs
cpu: 4
# -- CPU memory configuration
memory: 16Gi
# -- Number of gpus used
nvidia.com/gpu: 1
# -- Type of gpu used
gpuModels:
- "TYPE_GPU_USED"
# -- Autoscaling configuration
autoscaling:
# -- Enable autoscaling
enabled: false
# -- Minimum replicas
minReplicas: 1
# -- Maximum replicas
maxReplicas: 100
# -- Target CPU utilization for autoscaling
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
# -- Configmap
configs: {}
# -- Secrets configuration
secrets: {}
# -- External configuration
externalConfigs: []
# -- Custom Objects configuration
customObjects: []
# -- Disruption Budget Configuration
maxUnavailablePodDisruptionBudget: ""
# -- Additional configuration for the init container
extraInit:
# -- Path of the model on the s3 which hosts model weights and config files
s3modelpath: "relative_s3_model_path/opt-125m"
# -- Storage size of the s3
pvcStorage: "1Gi"
awsEc2MetadataDisabled: true
# -- Additional containers configuration
extraContainers: []
# -- Readiness probe configuration
readinessProbe:
# -- Number of seconds after the container has started before readiness probe is initiated
initialDelaySeconds: 5
# -- How often (in seconds) to perform the readiness probe
periodSeconds: 5
# -- Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not ready
failureThreshold: 3
# -- Configuration of the Kubelet http request on the server
httpGet:
# -- Path to access on the HTTP server
path: /health
# -- Name or number of the port to access on the container, on which the server is listening
port: 8000
# -- Liveness probe configuration
livenessProbe:
# -- Number of seconds after the container has started before liveness probe is initiated
initialDelaySeconds: 15
# -- Number of times after which if a probe fails in a row, Kubernetes considers that the overall check has failed: the container is not alive
failureThreshold: 3
# -- How often (in seconds) to perform the liveness probe
periodSeconds: 10
# -- Configuration of the Kubelet http request on the server
httpGet:
# -- Path to access on the HTTP server
path: /health
# -- Name or number of the port to access on the container, on which the server is listening
port: 8000
labels:
environment: "test"
release: "test"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
Note that `pip install cohere` is needed to run this example.
run: vllm serve BAAI/bge-reranker-base
"""
from typing import Union
import cohere
from cohere import Client, ClientV2
model = "BAAI/bge-reranker-base"
query = "What is the capital of France?"
documents = [
"The capital of France is Paris",
"Reranking is fun!",
"vLLM is an open-source framework for fast AI serving",
]
def cohere_rerank(
client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
) -> dict:
return client.rerank(model=model, query=query, documents=documents)
def main():
# cohere v1 client
cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
print("-" * 50)
print("rerank_v1_result:\n", rerank_v1_result)
print("-" * 50)
# or the v2
cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
print("rerank_v2_result:\n", rerank_v2_result)
print("-" * 50)
if __name__ == "__main__":
main()
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.
set -xe
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1
# meta-llama/Meta-Llama-3.1-8B-Instruct or deepseek-ai/DeepSeek-V2-Lite
MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# install quart first -- required for disagg prefill proxy serve
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi
# a function that waits vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# You can also adjust --kv-ip and --kv-port for distributed inference.
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (port 8100), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance
# NOTE: the usage of this API is subject to change --- in the future we will
# introduce "vllm connect" to connect between prefill and decode instances
python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
sleep 1
# serve two example requests
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo ""
sleep 1
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
# Disaggregated Serving
This example contains scripts that demonstrate the disaggregated serving features of vLLM.
## Files
- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances).
- `kv_events.sh` - Demonstrates KV cache event publishing.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file provides a disaggregated prefilling proxy demo to demonstrate an
example usage of XpYd disaggregated prefilling.
We can launch multiple vllm instances (2 for prefill and 2 for decode), and
launch this proxy demo through:
python3 examples/online_serving/disaggregated_serving/disagg_proxy_demo.py \
--model $model_name \
--prefill localhost:8100 localhost:8101 \
--decode localhost:8200 localhost:8201 \
--port 8000
Note: This demo will be removed once the PDController implemented in PR 15343
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
"""
import argparse
import ipaddress
import itertools
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
from typing import Callable, Optional
import aiohttp
import requests
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse, StreamingResponse
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
class SchedulingPolicy(ABC):
@abstractmethod
def schedule(self, cycler: itertools.cycle):
raise NotImplementedError("Scheduling Proxy is not set.")
class Proxy:
def __init__(
self,
prefill_instances: list[str],
decode_instances: list[str],
model: str,
scheduling_policy: SchedulingPolicy,
custom_create_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
custom_create_chat_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
):
self.prefill_instances = prefill_instances
self.decode_instances = decode_instances
self.prefill_cycler = itertools.cycle(prefill_instances)
self.decode_cycler = itertools.cycle(decode_instances)
self.model = model
self.scheduling_policy = scheduling_policy
self.custom_create_completion = custom_create_completion
self.custom_create_chat_completion = custom_create_chat_completion
self.router = APIRouter()
self.setup_routes()
def setup_routes(self):
self.router.post(
"/v1/completions", dependencies=[Depends(self.validate_json_request)]
)(
self.custom_create_completion
if self.custom_create_completion
else self.create_completion
)
self.router.post(
"/v1/chat/completions", dependencies=[Depends(self.validate_json_request)]
)(
self.custom_create_chat_completion
if self.custom_create_chat_completion
else self.create_chat_completion
)
self.router.get("/status", response_class=JSONResponse)(self.get_status)
self.router.post(
"/instances/add", dependencies=[Depends(self.api_key_authenticate)]
)(self.add_instance_endpoint)
async def validate_json_request(self, raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json":
raise HTTPException(
status_code=415,
detail="Unsupported Media Type: Only 'application/json' is allowed",
)
def api_key_authenticate(self, x_api_key: str = Header(...)):
expected_api_key = os.environ.get("ADMIN_API_KEY")
if not expected_api_key:
logger.error("ADMIN_API_KEY is not set in the environment.")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Server configuration error.",
)
if x_api_key != expected_api_key:
logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Forbidden: Invalid API Key.",
)
async def validate_instance(self, instance: str) -> bool:
url = f"http://{instance}/v1/models"
try:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
logger.info("Verifying %s ...", instance)
async with client.get(url) as response:
if response.status == 200:
data = await response.json()
if "data" in data and len(data["data"]) > 0:
model_cur = data["data"][0].get("id", "")
if model_cur == self.model:
logger.info("Instance: %s could be added.", instance)
return True
else:
logger.warning(
"Mismatch model %s : %s != %s",
instance,
model_cur,
self.model,
)
return False
else:
return False
else:
return False
except aiohttp.ClientError as e:
logger.error(str(e))
return False
except Exception as e:
logger.error(str(e))
return False
async def add_instance_endpoint(self, request: Request):
try:
data = await request.json()
logger.warning(str(data))
instance_type = data.get("type")
instance = data.get("instance")
if instance_type not in ["prefill", "decode"]:
raise HTTPException(status_code=400, detail="Invalid instance type.")
if not instance or ":" not in instance:
raise HTTPException(status_code=400, detail="Invalid instance format.")
host, port_str = instance.split(":")
try:
if host != "localhost":
ipaddress.ip_address(host)
port = int(port_str)
if not (0 < port < 65536):
raise HTTPException(status_code=400, detail="Invalid port number.")
except Exception as e:
raise HTTPException(
status_code=400, detail="Invalid instance address."
) from e
is_valid = await self.validate_instance(instance)
if not is_valid:
raise HTTPException(
status_code=400, detail="Instance validation failed."
)
if instance_type == "prefill":
if instance not in self.prefill_instances:
self.prefill_instances.append(instance)
self.prefill_cycler = itertools.cycle(self.prefill_instances)
else:
raise HTTPException(
status_code=400, detail="Instance already exists."
)
else:
if instance not in self.decode_instances:
self.decode_instances.append(instance)
self.decode_cycler = itertools.cycle(self.decode_instances)
else:
raise HTTPException(
status_code=400, detail="Instance already exists."
)
return JSONResponse(
content={"message": f"Added {instance} to {instance_type}_instances."}
)
except HTTPException as http_exc:
raise http_exc
except Exception as e:
logger.error("Error in add_instance_endpoint: %s", str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
async def forward_request(self, url, data, use_chunked=True):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
try:
async with session.post(
url=url, json=data, headers=headers
) as response:
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
if use_chunked:
async for chunk_bytes in response.content.iter_chunked( # noqa: E501
1024
):
yield chunk_bytes
else:
content = await response.read()
yield content
else:
error_content = await response.text()
try:
error_content = json.loads(error_content)
except json.JSONDecodeError:
error_content = error_content
logger.error(
"Request failed with status %s: %s",
response.status,
error_content,
)
raise HTTPException(
status_code=response.status,
detail=f"Request failed with status {response.status}: "
f"{error_content}",
)
except aiohttp.ClientError as e:
logger.error("ClientError occurred: %s", str(e))
raise HTTPException(
status_code=502,
detail="Bad Gateway: Error communicating with upstream server.",
) from e
except Exception as e:
logger.error("Unexpected error: %s", str(e))
raise HTTPException(status_code=500, detail=str(e)) from e
def schedule(self, cycler: itertools.cycle) -> str:
return self.scheduling_policy.schedule(cycler)
async def get_status(self):
status = {
"prefill_node_count": len(self.prefill_instances),
"decode_node_count": len(self.decode_instances),
"prefill_nodes": self.prefill_instances,
"decode_nodes": self.decode_instances,
}
return status
async def create_completion(self, raw_request: Request):
try:
request = await raw_request.json()
kv_prepare_request = request.copy()
kv_prepare_request["max_tokens"] = 1
prefill_instance = self.schedule(self.prefill_cycler)
try:
async for _ in self.forward_request(
f"http://{prefill_instance}/v1/completions", kv_prepare_request
):
continue
except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance)
raise http_exc
# Perform kv recv and decoding stage
decode_instance = self.schedule(self.decode_cycler)
try:
generator = self.forward_request(
f"http://{decode_instance}/v1/completions", request
)
except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance)
raise http_exc
response = StreamingResponse(generator)
return response
except Exception:
import sys
exc_info = sys.exc_info()
print("Error occurred in disagg proxy server")
print(exc_info)
async def create_chat_completion(self, raw_request: Request):
try:
request = await raw_request.json()
# add params to request
kv_prepare_request = request.copy()
kv_prepare_request["max_tokens"] = 1
# prefill stage
prefill_instance = self.schedule(self.prefill_cycler)
try:
async for _ in self.forward_request(
f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request
):
continue
except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance)
raise http_exc
# Perform kv recv and decoding stage
decode_instance = self.schedule(self.decode_cycler)
try:
generator = self.forward_request(
"http://" + decode_instance + "/v1/chat/completions", request
)
except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance)
raise http_exc
response = StreamingResponse(content=generator)
return response
except Exception:
exc_info = sys.exc_info()
error_messages = [str(e) for e in exc_info if e]
print("Error occurred in disagg proxy server")
print(error_messages)
return StreamingResponse(
content=iter(error_messages), media_type="text/event-stream"
)
def remove_instance_endpoint(self, instance_type, instance):
if instance_type == "decode" and instance in self.decode_instances:
self.decode_instances.remove(instance)
self.decode_cycler = itertools.cycle(self.decode_instances)
if instance_type == "prefill" and instance in self.decode_instances:
self.prefill_instances.remove(instance)
self.prefill_cycler = itertools.cycle(self.decode_instances)
class RoundRobinSchedulingPolicy(SchedulingPolicy):
def __init__(self):
super().__init__()
def schedule(self, cycler: itertools.cycle) -> str:
return next(cycler)
class ProxyServer:
def __init__(
self,
args: argparse.Namespace,
scheduling_policy: Optional[SchedulingPolicy] = None,
create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
):
self.validate_parsed_serve_args(args)
self.port = args.port
self.proxy_instance = Proxy(
prefill_instances=[] if args.prefill is None else args.prefill,
decode_instances=[] if args.decode is None else args.decode,
model=args.model,
scheduling_policy=(
scheduling_policy
if scheduling_policy is not None
else RoundRobinSchedulingPolicy()
),
custom_create_completion=create_completion,
custom_create_chat_completion=create_chat_completion,
)
def validate_parsed_serve_args(self, args: argparse.Namespace):
if not args.prefill:
raise ValueError("Please specify at least one prefill node.")
if not args.decode:
raise ValueError("Please specify at least one decode node.")
self.validate_instances(args.prefill)
self.validate_instances(args.decode)
self.verify_model_config(args.prefill, args.model)
self.verify_model_config(args.decode, args.model)
def validate_instances(self, instances: list):
for instance in instances:
if len(instance.split(":")) != 2:
raise ValueError(f"Invalid instance format: {instance}")
host, port = instance.split(":")
try:
if host != "localhost":
ipaddress.ip_address(host)
port = int(port)
if not (0 < port < 65536):
raise ValueError(f"Invalid port number in instance: {instance}")
except Exception as e:
raise ValueError(f"Invalid instance {instance}: {str(e)}") from e
def verify_model_config(self, instances: list, model: str) -> None:
model_suffix = model.split("/")[-1]
for instance in instances:
try:
response = requests.get(f"http://{instance}/v1/models")
if response.status_code == 200:
model_cur = response.json()["data"][0]["id"]
model_cur_suffix = model_cur.split("/")[-1]
if model_cur_suffix != model_suffix:
raise ValueError(
f"{instance} serves a different model: "
f"{model_cur} != {model}"
)
else:
raise ValueError(f"Cannot get model id from {instance}!")
except requests.RequestException as e:
raise ValueError(
f"Error communicating with {instance}: {str(e)}"
) from e
def run_server(self):
app = FastAPI()
app.include_router(self.proxy_instance.router)
config = uvicorn.Config(app, port=self.port, loop="uvloop")
server = uvicorn.Server(config)
server.run()
def parse_args():
# Todo: allow more config
parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
parser.add_argument("--model", "-m", type=str, required=True, help="Model name")
parser.add_argument(
"--prefill",
"-p",
type=str,
nargs="+",
help="List of prefill node URLs (host:port)",
)
parser.add_argument(
"--decode",
"-d",
type=str,
nargs="+",
help="List of decode node URLs (host:port)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Server port number",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
proxy_server = ProxyServer(args=args)
proxy_server.run_server()
#!/bin/bash
# This file demonstrates the KV cache event publishing
# We will launch a vllm instances configured to publish KV cache
# events and launch a simple subscriber to log those events.
set -xe
echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1
MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# a function that waits vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
vllm serve $MODEL_NAME \
--port 8100 \
--max-model-len 100 \
--enforce-eager \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-events-config \
'{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &
wait_for_server 8100
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1
# serve two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
# Cleanup commands
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm
sleep 1
echo "Cleaned up"
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment