Commit b4281383 authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

fix: update vLLM patch to 0aa204 (#92)

parent e0571935
...@@ -810,10 +810,10 @@ index 00000000..9b938039 ...@@ -810,10 +810,10 @@ index 00000000..9b938039
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 00000000..f1459cf9 index 00000000..9b757396
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,404 @@ @@ -0,0 +1,400 @@
+import torch +import torch
+from typing import List, Tuple +from typing import List, Tuple
+from vllm.config import VllmConfig +from vllm.config import VllmConfig
...@@ -896,12 +896,12 @@ index 00000000..f1459cf9 ...@@ -896,12 +896,12 @@ index 00000000..f1459cf9
+ for key_cache, value_cache in kv_caches: + for key_cache, value_cache in kv_caches:
+ base_addr = key_cache.data_ptr() + base_addr = key_cache.data_ptr()
+ region_len = 2 * num_blocks * self.block_len + region_len = 2 * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank)) + caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) + kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ +
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+ +
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data)) + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data) + logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs) + self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs) + self._registered_descs.append(descs)
...@@ -952,7 +952,7 @@ index 00000000..f1459cf9 ...@@ -952,7 +952,7 @@ index 00000000..f1459cf9
+ start_offset = range_start * block_len + start_offset = range_start * block_len
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank)) + blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank)) + blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data)) + return self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ +
+ def _get_ranges(self, block_ids): + def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous + # This function should return a list of ranges of block ids that are contiguous
...@@ -1146,16 +1146,12 @@ index 00000000..f1459cf9 ...@@ -1146,16 +1146,12 @@ index 00000000..f1459cf9
+ status = self.nixl_wrapper.transfer(handle) + status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status) + logger.debug("Transfer status: %s", status)
+ +
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
+ def get_notifs(self): + def get_notifs(self):
+ return self.nixl_wrapper.update_notifs() + return self.nixl_wrapper.update_notifs()
+ +
+ def get_new_notifs(self): + def get_new_notifs(self):
+ return self.nixl_wrapper.get_new_notifs() + return self.nixl_wrapper.get_new_notifs()
+
+ +
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks): + def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ self._tp_size[engine_id] = agent_tp + self._tp_size[engine_id] = agent_tp
...@@ -1182,8 +1178,8 @@ index 00000000..f1459cf9 ...@@ -1182,8 +1178,8 @@ index 00000000..f1459cf9
+ tp_multiplier_offset = i * dst_block_len + tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) + blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) + logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data)) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(descs) + self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side("", descs)
+ +
+ # create dst xfer side handles + # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks + self.dst_num_blocks[engine_id] = num_blocks
...@@ -1195,8 +1191,8 @@ index 00000000..f1459cf9 ...@@ -1195,8 +1191,8 @@ index 00000000..f1459cf9
+ block_offset = block_id * dst_block_len + block_offset = block_id * dst_block_len
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i)) + blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i) + logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(("VRAM", blocks_data)) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(descs, remote_agent=self._remote_agents[engine_id][self.rank * tp_multiplier + i]) + self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
+ +
+ return agent_names + return agent_names
+ +
...@@ -2923,7 +2919,7 @@ index d82d9ad9..931784f8 100644 ...@@ -2923,7 +2919,7 @@ index d82d9ad9..931784f8 100644
def _has_remaining_steps( def _has_remaining_steps(
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..6b90ece7 100644 index 3cf1850e..ae006579 100644
--- a/vllm/engine/multiprocessing/__init__.py --- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py
@@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput @@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput
...@@ -2979,7 +2975,7 @@ index 3cf1850e..6b90ece7 100644 ...@@ -2979,7 +2975,7 @@ index 3cf1850e..6b90ece7 100644
class RPCUProfileRequest(Enum): class RPCUProfileRequest(Enum):
START_PROFILE = 1 START_PROFILE = 1
@@ -157,3 +163,10 @@ def ENGINE_DEAD_ERROR( @@ -157,3 +163,13 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError( return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to " "Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.") f"find the original error: {repr(error)}.")
...@@ -2990,8 +2986,11 @@ index 3cf1850e..6b90ece7 100644 ...@@ -2990,8 +2986,11 @@ index 3cf1850e..6b90ece7 100644
+ request_total_slots: int + request_total_slots: int
+ kv_active_blocks: int + kv_active_blocks: int
+ kv_total_blocks: int + kv_total_blocks: int
+ num_requests_waiting: int
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..da207947 100644 index 85b5f31e..fe719642 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...@@ -3002,7 +3001,14 @@ index 85b5f31e..da207947 100644 ...@@ -3002,7 +3001,14 @@ index 85b5f31e..da207947 100644
import psutil import psutil
import zmq import zmq
import zmq.asyncio import zmq.asyncio
@@ -25,14 +26,16 @@ from vllm.engine.async_llm_engine import ( @@ -19,20 +20,23 @@ from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.metrics import Stats
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async) build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
...@@ -3022,7 +3028,7 @@ index 85b5f31e..da207947 100644 ...@@ -3022,7 +3028,7 @@ index 85b5f31e..da207947 100644
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
@@ -46,6 +49,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest @@ -46,6 +50,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs from vllm.utils import deprecate_kwargs
...@@ -3031,7 +3037,7 @@ index 85b5f31e..da207947 100644 ...@@ -3031,7 +3037,7 @@ index 85b5f31e..da207947 100644
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -91,6 +96,7 @@ class MQLLMEngineClient(EngineClient): @@ -91,6 +97,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Get the configs. # Get the configs.
...@@ -3039,7 +3045,7 @@ index 85b5f31e..da207947 100644 ...@@ -3039,7 +3045,7 @@ index 85b5f31e..da207947 100644
self.model_config = engine_config.model_config self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config self.decoding_config = engine_config.decoding_config
@@ -115,6 +121,10 @@ class MQLLMEngineClient(EngineClient): @@ -115,6 +122,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -3050,7 +3056,7 @@ index 85b5f31e..da207947 100644 ...@@ -3050,7 +3056,7 @@ index 85b5f31e..da207947 100644
# IPC path for the data socket. # IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,8 +139,27 @@ class MQLLMEngineClient(EngineClient): @@ -129,8 +140,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically. # Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready. # Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None self.health_loop: Optional[asyncio.Task] = None
...@@ -3078,7 +3084,7 @@ index 85b5f31e..da207947 100644 ...@@ -3078,7 +3084,7 @@ index 85b5f31e..da207947 100644
@staticmethod @staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs): def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported # Pipeline parallel not yet supported
@@ -180,6 +209,56 @@ class MQLLMEngineClient(EngineClient): @@ -180,6 +210,63 @@ class MQLLMEngineClient(EngineClient):
except Exception as e: except Exception as e:
self._set_errored(e) self._set_errored(e)
...@@ -3111,15 +3117,22 @@ index 85b5f31e..da207947 100644 ...@@ -3111,15 +3117,22 @@ index 85b5f31e..da207947 100644
+ if await self.metrics_socket.poll(timeout=timeout): + if await self.metrics_socket.poll(timeout=timeout):
+ # Metrics received- check the message + # Metrics received- check the message
+ message: Frame = await self.metrics_socket.recv(copy=False) + message: Frame = await self.metrics_socket.recv(copy=False)
+ kv_metrics = pickle.loads(message.buffer) + metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None: + if self.metrics_publisher is not None:
+ if isinstance(kv_metrics, KvMetrics): + if isinstance(metrics, KvMetrics):
+ self.metrics_publisher.publish(kv_metrics.request_active_slots, + self.metrics_publisher.publish(metrics.request_active_slots,
+ kv_metrics.request_total_slots, + metrics.request_total_slots,
+ kv_metrics.kv_active_blocks, + metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks) + metrics.kv_total_blocks,
+ + metrics.num_requests_waiting,
+ logger.debug("Metircs successful.") + metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ if isinstance(metrics, Stats):
+ # TODO
+ # Send the whole stats to user
+ pass
+
+ logger.debug("Metrics successful.")
+ +
+ except asyncio.CancelledError: + except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
...@@ -3135,7 +3148,7 @@ index 85b5f31e..da207947 100644 ...@@ -3135,7 +3148,7 @@ index 85b5f31e..da207947 100644
async def run_output_handler_loop(self): async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues""" """Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +357,26 @@ class MQLLMEngineClient(EngineClient): @@ -278,12 +365,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready. # Wait until server is ready.
response = await self._wait_for_server_rpc(socket) response = await self._wait_for_server_rpc(socket)
...@@ -3162,7 +3175,7 @@ index 85b5f31e..da207947 100644 ...@@ -3162,7 +3175,7 @@ index 85b5f31e..da207947 100644
def close(self): def close(self):
"""Destroy the ZeroMQ Context.""" """Destroy the ZeroMQ Context."""
@@ -293,6 +386,8 @@ class MQLLMEngineClient(EngineClient): @@ -293,6 +394,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks. # Cancel background tasks.
if self.health_loop is not None: if self.health_loop is not None:
self.health_loop.cancel() self.health_loop.cancel()
...@@ -3171,7 +3184,7 @@ index 85b5f31e..da207947 100644 ...@@ -3171,7 +3184,7 @@ index 85b5f31e..da207947 100644
if self.output_loop is not None: if self.output_loop is not None:
self.output_loop.cancel() self.output_loop.cancel()
@@ -415,6 +510,9 @@ class MQLLMEngineClient(EngineClient): @@ -415,6 +518,9 @@ class MQLLMEngineClient(EngineClient):
""" """
if self._errored_with is not None: if self._errored_with is not None:
raise self._errored_with raise self._errored_with
...@@ -3181,7 +3194,7 @@ index 85b5f31e..da207947 100644 ...@@ -3181,7 +3194,7 @@ index 85b5f31e..da207947 100644
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -473,6 +571,7 @@ class MQLLMEngineClient(EngineClient): @@ -473,6 +579,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3189,7 +3202,7 @@ index 85b5f31e..da207947 100644 ...@@ -3189,7 +3202,7 @@ index 85b5f31e..da207947 100644
*, *,
inputs: Optional[PromptType] = None # DEPRECATED inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +601,8 @@ class MQLLMEngineClient(EngineClient): @@ -502,7 +609,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id, return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
...@@ -3199,7 +3212,7 @@ index 85b5f31e..da207947 100644 ...@@ -3199,7 +3212,7 @@ index 85b5f31e..da207947 100644
@overload @overload
def encode( def encode(
@@ -586,6 +686,7 @@ class MQLLMEngineClient(EngineClient): @@ -586,6 +694,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
...@@ -3207,7 +3220,7 @@ index 85b5f31e..da207947 100644 ...@@ -3207,7 +3220,7 @@ index 85b5f31e..da207947 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]: PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses.""" """Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +731,12 @@ class MQLLMEngineClient(EngineClient): @@ -630,6 +739,12 @@ class MQLLMEngineClient(EngineClient):
else: else:
lp_bytes = None lp_bytes = None
...@@ -3220,7 +3233,7 @@ index 85b5f31e..da207947 100644 ...@@ -3220,7 +3233,7 @@ index 85b5f31e..da207947 100644
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
prompt=prompt, prompt=prompt,
@@ -639,11 +746,11 @@ class MQLLMEngineClient(EngineClient): @@ -639,11 +754,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
priority=priority, priority=priority,
...@@ -3234,7 +3247,7 @@ index 85b5f31e..da207947 100644 ...@@ -3234,7 +3247,7 @@ index 85b5f31e..da207947 100644
await self.input_socket.send_multipart(parts, copy=False) await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note # 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +812,6 @@ class MQLLMEngineClient(EngineClient): @@ -705,3 +820,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None # Raise on error, otherwise happily return None
if isinstance(request_output, BaseException): if isinstance(request_output, BaseException):
raise request_output raise request_output
...@@ -3242,10 +3255,10 @@ index 85b5f31e..da207947 100644 ...@@ -3242,10 +3255,10 @@ index 85b5f31e..da207947 100644
+ def set_metrics_publisher(self, metrics_publisher): + def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher + self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..dbd9d58d 100644 index a0dd7958..3204cfb8 100644
--- a/vllm/engine/multiprocessing/engine.py --- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,73 @@ @@ -3,35 +3,115 @@
import pickle import pickle
import signal import signal
from contextlib import contextmanager from contextlib import contextmanager
...@@ -3305,26 +3318,68 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3305,26 +3318,68 @@ index a0dd7958..dbd9d58d 100644
+ self.metrics_socket = metrics_socket + self.metrics_socket = metrics_socket
+ +
+ # KV metrics + # KV metrics
+ self._send_kv_metrics(0, 0) + self._send_kv_metrics(0, 0, 0, 0.0, 0.0)
+ +
+ def log(self, stats: Stats) -> None: + def log(self, stats: Stats) -> None:
+ self._send_kv_metrics( + self._send_kv_metrics(
+ stats.num_running_sys, + stats.num_running_sys,
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks) + int(stats.gpu_cache_usage_sys * self.kv_total_blocks),
+ stats.num_waiting_sys,
+ stats.gpu_cache_usage_sys,
+ stats.gpu_prefix_cache_hit_rate
+ ) + )
+ +
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None: + def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass + pass
+ +
+ def _send_kv_metrics(self, active_slots, active_kv_blocks): + def _send_kv_metrics(
+ self,
+ active_slots,
+ active_kv_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ ):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(
+ KvMetrics(
+ active_slots,
+ self.request_total_slots,
+ active_kv_blocks,
+ self.kv_total_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ )
+ )
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ metrics_socket
+ ):
+ self.metrics_socket = metrics_socket
+
+ def log(self, stats: Stats) -> None:
+ self._send_metrics(stats)
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed: + if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks)) + metrics_bytes = pickle.dumps(stats)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+
+ +
class MQLLMEngine: class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`. """A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +132,31 @@ class MQLLMEngine: @@ -94,12 +174,35 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -3346,17 +3401,21 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3346,17 +3401,21 @@ index a0dd7958..dbd9d58d 100644
+ +
+ +
+ # Attach logger for continuous metrics publishing + # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger( + self.kv_stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs, + self.engine.scheduler_config.max_num_seqs,
+ self.engine.cache_config.num_gpu_blocks, + self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket + self.metrics_socket
+ ) + )
+ self.engine.add_logger("kv_metrics", self.stat_logger) + self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger)
+ +
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
if self._errored_with is not None: if self._errored_with is not None:
@@ -171,8 +228,17 @@ class MQLLMEngine: @@ -171,8 +274,17 @@ class MQLLMEngine:
# Handle the query from the Client. # Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY: if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled() tracing_enabled = self.engine.is_tracing_enabled()
...@@ -3376,7 +3435,7 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3376,7 +3435,7 @@ index a0dd7958..dbd9d58d 100644
except Exception as e: except Exception as e:
response = e response = e
@@ -185,6 +251,7 @@ class MQLLMEngine: @@ -185,6 +297,7 @@ class MQLLMEngine:
while True: while True:
if not self.engine.has_unfinished_requests(): if not self.engine.has_unfinished_requests():
...@@ -3384,7 +3443,7 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3384,7 +3443,7 @@ index a0dd7958..dbd9d58d 100644
# Poll until there is work to do. # Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send # When there's no work, check on engine health and send
@@ -220,6 +287,13 @@ class MQLLMEngine: @@ -220,6 +333,13 @@ class MQLLMEngine:
def handle_new_input(self): def handle_new_input(self):
"""Handle new input from the socket""" """Handle new input from the socket"""
try: try:
...@@ -3398,7 +3457,7 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3398,7 +3457,7 @@ index a0dd7958..dbd9d58d 100644
while self.input_socket.poll(timeout=0) != 0: while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False) frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer) request = pickle.loads(frames[0].buffer)
@@ -262,6 +336,11 @@ class MQLLMEngine: @@ -262,6 +382,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err) self._send_outputs(rpc_err)
try: try:
...@@ -3410,7 +3469,7 @@ index a0dd7958..dbd9d58d 100644 ...@@ -3410,7 +3469,7 @@ index a0dd7958..dbd9d58d 100644
self.engine.add_request( self.engine.add_request(
request_id=request_id, request_id=request_id,
prompt=request.prompt, prompt=request.prompt,
@@ -269,7 +348,9 @@ class MQLLMEngine: @@ -269,7 +394,9 @@ class MQLLMEngine:
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request, prompt_adapter_request=request.prompt_adapter_request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment