Commit b4281383 authored by ptarasiewiczNV, committed by GitHub
Browse files

fix: update vLLM patch to 0aa204 (#92)

parent e0571935
......@@ -810,10 +810,10 @@ index 00000000..9b938039
\ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..f1459cf9
index 00000000..9b757396
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,404 @@
@@ -0,0 +1,400 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
......@@ -896,12 +896,12 @@ index 00000000..f1459cf9
+ for key_cache, value_cache in kv_caches:
+ base_addr = key_cache.data_ptr()
+ region_len = 2 * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
......@@ -952,7 +952,7 @@ index 00000000..f1459cf9
+ start_offset = range_start * block_len
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ return self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+
+ def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous
......@@ -1147,16 +1147,12 @@ index 00000000..f1459cf9
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
+ def get_notifs(self):
+ return self.nixl_wrapper.update_notifs()
+
+ def get_new_notifs(self):
+ return self.nixl_wrapper.get_new_notifs()
+
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
......@@ -1182,8 +1178,8 @@ index 00000000..f1459cf9
+ tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(descs)
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side("", descs)
+
+ # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks
......@@ -1195,8 +1191,8 @@ index 00000000..f1459cf9
+ block_offset = block_id * dst_block_len
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(descs, remote_agent=self._remote_agents[engine_id][self.rank * tp_multiplier + i])
+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
+
+ return agent_names
+
......@@ -2923,7 +2919,7 @@ index d82d9ad9..931784f8 100644
def _has_remaining_steps(
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..6b90ece7 100644
index 3cf1850e..ae006579 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput
......@@ -2979,7 +2975,7 @@ index 3cf1850e..6b90ece7 100644
class RPCUProfileRequest(Enum):
START_PROFILE = 1
@@ -157,3 +163,10 @@ def ENGINE_DEAD_ERROR(
@@ -157,3 +163,13 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
......@@ -2990,8 +2986,11 @@ index 3cf1850e..6b90ece7 100644
+ request_total_slots: int
+ kv_active_blocks: int
+ kv_total_blocks: int
+ num_requests_waiting: int
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..da207947 100644
index 85b5f31e..fe719642 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
......@@ -3002,7 +3001,14 @@ index 85b5f31e..da207947 100644
import psutil
import zmq
import zmq.asyncio
@@ -25,14 +26,16 @@ from vllm.engine.async_llm_engine import (
@@ -19,20 +20,23 @@ from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.metrics import Stats
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
......@@ -3022,7 +3028,7 @@ index 85b5f31e..da207947 100644
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -46,6 +49,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -46,6 +50,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
......@@ -3031,7 +3037,7 @@ index 85b5f31e..da207947 100644
logger = init_logger(__name__)
@@ -91,6 +96,7 @@ class MQLLMEngineClient(EngineClient):
@@ -91,6 +97,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
......@@ -3039,7 +3045,7 @@ index 85b5f31e..da207947 100644
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
@@ -115,6 +121,10 @@ class MQLLMEngineClient(EngineClient):
@@ -115,6 +122,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
......@@ -3050,7 +3056,7 @@ index 85b5f31e..da207947 100644
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,8 +139,27 @@ class MQLLMEngineClient(EngineClient):
@@ -129,8 +140,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
......@@ -3078,7 +3084,7 @@ index 85b5f31e..da207947 100644
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
@@ -180,6 +209,56 @@ class MQLLMEngineClient(EngineClient):
@@ -180,6 +210,63 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
......@@ -3111,15 +3117,22 @@ index 85b5f31e..da207947 100644
+ if await self.metrics_socket.poll(timeout=timeout):
+ # Metrics received- check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ kv_metrics = pickle.loads(message.buffer)
+ metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if isinstance(kv_metrics, KvMetrics):
+ self.metrics_publisher.publish(kv_metrics.request_active_slots,
+ kv_metrics.request_total_slots,
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+ if isinstance(metrics, KvMetrics):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_total_blocks,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ if isinstance(metrics, Stats):
+ # TODO
+ # Send the whole stats to user
+ pass
+
+ logger.debug("Metircs successful.")
+ logger.debug("Metrics successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
......@@ -3135,7 +3148,7 @@ index 85b5f31e..da207947 100644
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +357,26 @@ class MQLLMEngineClient(EngineClient):
@@ -278,12 +365,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
......@@ -3162,7 +3175,7 @@ index 85b5f31e..da207947 100644
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -293,6 +386,8 @@ class MQLLMEngineClient(EngineClient):
@@ -293,6 +394,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
......@@ -3171,7 +3184,7 @@ index 85b5f31e..da207947 100644
if self.output_loop is not None:
self.output_loop.cancel()
@@ -415,6 +510,9 @@ class MQLLMEngineClient(EngineClient):
@@ -415,6 +518,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
......@@ -3181,7 +3194,7 @@ index 85b5f31e..da207947 100644
@property
def is_running(self) -> bool:
@@ -473,6 +571,7 @@ class MQLLMEngineClient(EngineClient):
@@ -473,6 +579,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3189,7 +3202,7 @@ index 85b5f31e..da207947 100644
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +601,8 @@ class MQLLMEngineClient(EngineClient):
@@ -502,7 +609,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
......@@ -3199,7 +3212,7 @@ index 85b5f31e..da207947 100644
@overload
def encode(
@@ -586,6 +686,7 @@ class MQLLMEngineClient(EngineClient):
@@ -586,6 +694,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -3207,7 +3220,7 @@ index 85b5f31e..da207947 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +731,12 @@ class MQLLMEngineClient(EngineClient):
@@ -630,6 +739,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
......@@ -3220,7 +3233,7 @@ index 85b5f31e..da207947 100644
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@@ -639,11 +746,11 @@ class MQLLMEngineClient(EngineClient):
@@ -639,11 +754,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
......@@ -3234,7 +3247,7 @@ index 85b5f31e..da207947 100644
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +812,6 @@ class MQLLMEngineClient(EngineClient):
@@ -705,3 +820,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
......@@ -3242,10 +3255,10 @@ index 85b5f31e..da207947 100644
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..dbd9d58d 100644
index a0dd7958..3204cfb8 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,73 @@
@@ -3,35 +3,115 @@
import pickle
import signal
from contextlib import contextmanager
......@@ -3305,26 +3318,68 @@ index a0dd7958..dbd9d58d 100644
+ self.metrics_socket = metrics_socket
+
+ # KV metrics
+ self._send_kv_metrics(0, 0)
+ self._send_kv_metrics(0, 0, 0, 0.0, 0.0)
+
+ def log(self, stats: Stats) -> None:
+ self._send_kv_metrics(
+ stats.num_running_sys,
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks)
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks),
+ stats.num_waiting_sys,
+ stats.gpu_cache_usage_sys,
+ stats.gpu_prefix_cache_hit_rate
+ )
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_kv_metrics(
+ self,
+ active_slots,
+ active_kv_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ ):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(
+ KvMetrics(
+ active_slots,
+ self.request_total_slots,
+ active_kv_blocks,
+ self.kv_total_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ )
+ )
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ metrics_socket
+ ):
+ self.metrics_socket = metrics_socket
+
+ def log(self, stats: Stats) -> None:
+ self._send_metrics(stats)
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_kv_metrics(self, active_slots, active_kv_blocks):
+ def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks))
+ metrics_bytes = pickle.dumps(stats)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+
+
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +132,31 @@ class MQLLMEngine:
@@ -94,12 +174,35 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
......@@ -3346,17 +3401,21 @@ index a0dd7958..dbd9d58d 100644
+
+
+ # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger(
+ self.kv_stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs,
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.stat_logger)
+ self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger)
+
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@@ -171,8 +228,17 @@ class MQLLMEngine:
@@ -171,8 +274,17 @@ class MQLLMEngine:
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
......@@ -3376,7 +3435,7 @@ index a0dd7958..dbd9d58d 100644
except Exception as e:
response = e
@@ -185,6 +251,7 @@ class MQLLMEngine:
@@ -185,6 +297,7 @@ class MQLLMEngine:
while True:
if not self.engine.has_unfinished_requests():
......@@ -3384,7 +3443,7 @@ index a0dd7958..dbd9d58d 100644
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
@@ -220,6 +287,13 @@ class MQLLMEngine:
@@ -220,6 +333,13 @@ class MQLLMEngine:
def handle_new_input(self):
"""Handle new input from the socket"""
try:
......@@ -3398,7 +3457,7 @@ index a0dd7958..dbd9d58d 100644
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
@@ -262,6 +336,11 @@ class MQLLMEngine:
@@ -262,6 +382,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
......@@ -3410,7 +3469,7 @@ index a0dd7958..dbd9d58d 100644
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
@@ -269,7 +348,9 @@ class MQLLMEngine:
@@ -269,7 +394,9 @@ class MQLLMEngine:
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment