Commit 60a73634 authored by ptarasiewiczNV, committed by GitHub

feat: use vllm out of process engine


Signed-off-by: Piotr Marcinkiewicz <piotrm@nvidia.com>
Co-authored-by: Piotr Marcinkiewicz <piotrm@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 65a2dfab
@@ -94,7 +94,7 @@ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=0 python3 -m disaggregat
--enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
'{"kv_connector":"TritonNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
```
**Terminal 2 - Decode Worker:**
@@ -107,7 +107,7 @@ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=1,2 python3 -m disaggreg
--enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
'{"kv_connector":"TritonNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
```
The disaggregated deployment uses separate GPUs for prefill and decode operations, allowing resources to be allocated to each stage independently and improving performance. For more details on the disaggregated deployment, please refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/features/disagg_prefill.html).
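The two `--kv-transfer-config` values differ only in `kv_role` and `kv_rank`; the connector name and `kv_parallel_size` must match across workers. A minimal sanity-check sketch using only the JSON shown above:

```python
import json

# Producer (prefill) and consumer (decode) configs from the two terminals above.
prefill_cfg = json.loads(
    '{"kv_connector":"TritonNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
)
decode_cfg = json.loads(
    '{"kv_connector":"TritonNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
)

# Both workers agree on the connector and total parallel size; only role and rank differ.
assert prefill_cfg["kv_connector"] == decode_cfg["kv_connector"]
assert prefill_cfg["kv_parallel_size"] == decode_cfg["kv_parallel_size"]
assert {prefill_cfg["kv_role"], decode_cfg["kv_role"]} == {"kv_producer", "kv_consumer"}
```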
......
@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import vllm
from common.chat_processor import ChatProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
logger = logging.getLogger("vllm")
class BaseVllmEngine:
@@ -26,11 +32,43 @@ class BaseVllmEngine:
"""
def __init__(self, engine_args: AsyncEngineArgs):
self.model_config = engine_args.create_model_config()
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.chat_processor = ChatProcessor(self.engine, self.model_config)
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.engine_client = None
self.chat_processor = None
self._engine_context = None
async def initialize(self):
"""Initialize the engine client and related components."""
print("Initializing engine client")
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
self.chat_processor = ChatProcessor(self.engine_client, self.model_config)
else:
raise RuntimeError("Failed to initialize engine client")
async def cleanup(self):
"""Cleanup resources."""
print("Cleaning up engine client")
if self._engine_context is not None:
await self._engine_context.__aexit__(None, None, None)
self._engine_context = None
self.engine_client = None
self.chat_processor = None
async def __aenter__(self):
"""Initialize with context manager syntax."""
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
async def _parse_raw_request(self, raw_request):
assert self.engine_client is not None
request = self.chat_processor.parse_raw_request(raw_request)
(
conversation,
@@ -49,6 +87,7 @@ class BaseVllmEngine:
return request, conversation, request_prompt, engine_prompt, sampling_params
async def _stream_response(self, request, generator, request_id, conversation):
assert self.engine_client is not None
return self.chat_processor.stream_response(
request,
generator,
......
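With this change the vLLM engine runs out of process: `build_async_engine_client_from_engine_args` returns an async context manager, and `BaseVllmEngine.initialize()`/`cleanup()` enter and exit it, so the class itself can be used with `async with`. A minimal lifecycle sketch (the import path for `BaseVllmEngine` and the model name are illustrative assumptions, not taken from the diff):

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs

# Hypothetical import path; the diff does not show which module defines BaseVllmEngine.
from common.base_engine import BaseVllmEngine


async def main():
    # Any vLLM-servable model works here; "facebook/opt-125m" is only an example.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)
    async with BaseVllmEngine(engine_args) as engine:
        # Inside the block the out-of-process engine client and chat processor are ready.
        assert engine.engine_client is not None
        assert engine.chat_processor is not None
    # __aexit__ has called cleanup(), tearing down the engine client.


asyncio.run(main())
```

The workers below follow the same pattern, wrapping `serve_endpoint` in `async with VllmDecodeEngine(...)`, `async with VllmPrefillEngine(...)`, and `async with VllmEngine(...)`.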
@@ -15,7 +15,7 @@
import asyncio
import random
import socket
import uuid
import msgspec
@@ -37,23 +37,26 @@ class VllmDecodeEngine(BaseVllmEngine):
Request handler for the generate endpoint
"""
def __init__(self, engine_args: AsyncEngineArgs):
def __init__(self, engine_args: AsyncEngineArgs, prefill):
assert (
engine_args.kv_transfer_config.is_kv_consumer
), "Decode worker must be a KV consumer"
if engine_args.enable_chunked_prefill is not False:
vllm_logger.info(
"Chunked prefill is not supported in disaggregated mode, disabling it"
)
engine_args.enable_chunked_prefill = False
super().__init__(engine_args)
self.prefills: list = []
self.prefill = prefill
self.num_prefill_workers = (
self.engine.engine.vllm_config.kv_transfer_config.kv_producers_parallel_size
)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
def add_prefill(self, prefill):
self.prefills.append(prefill)
self.kv_transfer_config = engine_args.create_engine_config().kv_transfer_config
self.kv_rank = self.kv_transfer_config.kv_rank
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
if self.engine_client is None:
await self.initialize()
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
@@ -62,25 +65,32 @@ class VllmDecodeEngine(BaseVllmEngine):
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
prefill_rank = random.choice(range(self.num_prefill_workers))
request_id = f"{uuid.uuid4()}___prefill_kv_rank_{prefill_rank}___decode_kv_rank_{self.kv_rank}"
# TODO: pass decode info through a separate request param
request_id = f"{uuid.uuid4()}___decode_hostname_{socket.gethostname()}___decode_kv_rank_{self.kv_rank}"
prefill_sampling_params = {**msgspec.to_builtins(sampling_params)}
prefill_sampling_params["max_tokens"] = 1
prefill_sampling_params["min_tokens"] = 1
prefill_request = PrefillRequest(
prompt=request_prompt, # TODO: we should use engine prompt to avoid extra tokenization
sampling_params=prefill_sampling_params,
request_id=request_id,
)
vllm_logger.debug(f"Prefill request: {prefill_request}")
self.prefills[prefill_rank].generate(
prefill_output = self.prefill.generate(
prefill_request.model_dump_json(),
)
vllm_logger.debug(
f"Running generate with engine_prompt: {engine_prompt}, sampling_params: {sampling_params}, request_id: {request_id}"
)
generator = self.engine.generate(engine_prompt, sampling_params, request_id)
if self.engine_client is None:
raise RuntimeError("Engine client not initialized")
else:
generator = self.engine_client.generate(
engine_prompt, sampling_params, request_id
)
async for response in await self._stream_response(
request, generator, request_id, conversation
@@ -88,6 +98,8 @@ class VllmDecodeEngine(BaseVllmEngine):
vllm_logger.debug(f"Generated response: {response}")
yield response
await prefill_output
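The `request_id` built above doubles as a routing hint: it embeds the decode worker's hostname and KV rank (presumably read back by the KV connector on the prefill side), which is what the TODO about a separate request parameter refers to. A small illustrative helper for unpacking that format, not part of the diff:

```python
import socket
import uuid


def parse_decode_request_id(request_id: str) -> tuple[str, str, int]:
    """Split "<uuid>___decode_hostname_<host>___decode_kv_rank_<rank>" into its parts."""
    uid, host_part, rank_part = request_id.split("___")
    hostname = host_part.removeprefix("decode_hostname_")
    kv_rank = int(rank_part.removeprefix("decode_kv_rank_"))
    return uid, hostname, kv_rank


# Example using the same format the decode worker builds above.
rid = f"{uuid.uuid4()}___decode_hostname_{socket.gethostname()}___decode_kv_rank_1"
print(parse_decode_request_id(rid))
```

Note that the prefill request pins `max_tokens` and `min_tokens` to 1, so the remote worker only populates the KV cache, and the trailing `await prefill_output` ensures the remote call completes before the request handler returns.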
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
@@ -98,18 +110,15 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("vllm")
await component.create_service()
decode_engine = VllmDecodeEngine(engine_args)
for i in range(decode_engine.num_prefill_workers):
prefill = (
await runtime.namespace("triton-init")
.component("prefill")
.endpoint(f"generate_kv_rank_{i}")
.client()
)
decode_engine.add_prefill(prefill)
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(decode_engine.generate)
prefill = (
await runtime.namespace("triton-init")
.component("prefill")
.endpoint("generate")
.client()
)
async with VllmDecodeEngine(engine_args, prefill) as decode_engine:
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(decode_engine.generate)
if __name__ == "__main__":
......
@@ -35,18 +35,30 @@ class VllmPrefillEngine(BaseVllmEngine):
assert (
engine_args.kv_transfer_config.is_kv_producer
), "Prefill worker must be a KV producer"
if engine_args.enable_chunked_prefill is not False:
vllm_logger.info(
"Chunked prefill is not supported in disaggregated mode, disabling it"
)
engine_args.enable_chunked_prefill = False
super().__init__(engine_args)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
self.kv_transfer_config = engine_args.create_engine_config().kv_transfer_config
self.kv_rank = self.kv_transfer_config.kv_rank
@triton_endpoint(PrefillRequest, PrefillResponse)
async def generate(self, request):
vllm_logger.info(f"Received prefill request: {request}")
if self.engine_client is None:
await self.initialize()
vllm_logger.debug(f"Received prefill request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params)
async for response in self.engine.generate(
request.prompt, sampling_params, request.request_id
):
vllm_logger.debug(f"Generated response: {response}")
yield True
if self.engine_client is None:
raise RuntimeError("Engine client not initialized")
else:
async for response in self.engine_client.generate(
request.prompt, sampling_params, request.request_id
):
vllm_logger.debug(f"Generated response: {response}")
yield True
@triton_worker()
@@ -58,9 +70,9 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("prefill")
await component.create_service()
prefill_engine = VllmPrefillEngine(engine_args)
endpoint = component.endpoint(f"generate_kv_rank_{prefill_engine.kv_rank}")
await endpoint.serve_endpoint(prefill_engine.generate)
async with VllmPrefillEngine(engine_args) as prefill_engine:
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(prefill_engine.generate)
if __name__ == "__main__":
......
@@ -39,6 +39,9 @@ class VllmEngine(BaseVllmEngine):
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
if self.engine_client is None:
await self.initialize()
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
@@ -52,7 +55,12 @@ class VllmEngine(BaseVllmEngine):
vllm_logger.debug(
f"Running generate with engine_prompt: {engine_prompt}, sampling_params: {sampling_params}, request_id: {request_id}"
)
generator = self.engine.generate(engine_prompt, sampling_params, request_id)
if self.engine_client is None:
raise RuntimeError("Engine client not initialized")
else:
generator = self.engine_client.generate(
engine_prompt, sampling_params, request_id
)
async for response in await self._stream_response(
request, generator, request_id, conversation
@@ -71,7 +79,9 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(VllmEngine(engine_args).generate)
async with VllmEngine(engine_args) as engine:
await endpoint.serve_endpoint(engine.generate)
if __name__ == "__main__":
......