Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b4281383
Commit
b4281383
authored
Mar 11, 2025
by
ptarasiewiczNV
Committed by
GitHub
Mar 11, 2025
Browse files
fix: update vLLM patch to 0aa204 (#92)
parent
e0571935
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
113 additions
and
54 deletions
+113
-54
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
+113
-54
No files found.
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
View file @
b4281383
...
@@ -810,10 +810,10 @@ index 00000000..9b938039
...
@@ -810,10 +810,10 @@ index 00000000..9b938039
\
No newline at end of file
\
No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
new file mode 100644
index 00000000..
f1459cf9
index 00000000..
9b757396
--- /dev/null
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,40
4
@@
@@ -0,0 +1,40
0
@@
+import torch
+import torch
+from typing import List, Tuple
+from typing import List, Tuple
+from vllm.config import VllmConfig
+from vllm.config import VllmConfig
...
@@ -896,12 +896,12 @@ index 00000000..f1459cf9
...
@@ -896,12 +896,12 @@ index 00000000..f1459cf9
+ for key_cache, value_cache in kv_caches:
+ for key_cache, value_cache in kv_caches:
+ base_addr = key_cache.data_ptr()
+ base_addr = key_cache.data_ptr()
+ region_len = 2 * num_blocks * self.block_len
+ region_len = 2 * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ caches_data.append((base_addr, region_len, self.rank
, ""
))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+
+ descs = self.nixl_wrapper.get_descs(
("VRAM",
caches_data
)
)
+ descs = self.nixl_wrapper.get_
reg_
descs(caches_data
, "VRAM"
)
+ logger.debug("Registering descs: %s", caches_data)
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+ self._registered_descs.append(descs)
...
@@ -952,7 +952,7 @@ index 00000000..f1459cf9
...
@@ -952,7 +952,7 @@ index 00000000..f1459cf9
+ start_offset = range_start * block_len
+ start_offset = range_start * block_len
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((key_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ blocks_data.append((value_base_addr + start_offset, range_len * block_len, rank))
+ return self.nixl_wrapper.get_descs(
("VRAM",
blocks_data
)
)
+ return self.nixl_wrapper.get_
xfer_
descs(blocks_data
, "VRAM"
)
+
+
+ def _get_ranges(self, block_ids):
+ def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous
+ # This function should return a list of ranges of block ids that are contiguous
...
@@ -1146,16 +1146,12 @@ index 00000000..f1459cf9
...
@@ -1146,16 +1146,12 @@ index 00000000..f1459cf9
+ status = self.nixl_wrapper.transfer(handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ logger.debug("Transfer status: %s", status)
+
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
+ def get_notifs(self):
+ def get_notifs(self):
+ return self.nixl_wrapper.update_notifs()
+ return self.nixl_wrapper.update_notifs()
+
+
+ def get_new_notifs(self):
+ def get_new_notifs(self):
+ return self.nixl_wrapper.get_new_notifs()
+ return self.nixl_wrapper.get_new_notifs()
+
+
+
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
+ self._tp_size[engine_id] = agent_tp
+ self._tp_size[engine_id] = agent_tp
...
@@ -1182,8 +1178,8 @@ index 00000000..f1459cf9
...
@@ -1182,8 +1178,8 @@ index 00000000..f1459cf9
+ tp_multiplier_offset = i * dst_block_len
+ tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(
("VRAM",
blocks_data
)
)
+ descs = self.nixl_wrapper.get_
xfer_
descs(blocks_data
, "VRAM"
)
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(descs)
+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_side(
"",
descs)
+
+
+ # create dst xfer side handles
+ # create dst xfer side handles
+ self.dst_num_blocks[engine_id] = num_blocks
+ self.dst_num_blocks[engine_id] = num_blocks
...
@@ -1195,8 +1191,8 @@ index 00000000..f1459cf9
...
@@ -1195,8 +1191,8 @@ index 00000000..f1459cf9
+ block_offset = block_id * dst_block_len
+ block_offset = block_id * dst_block_len
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
+ descs = self.nixl_wrapper.get_descs(
("VRAM",
blocks_data
)
)
+ descs = self.nixl_wrapper.get_
xfer_
descs(blocks_data
, "VRAM"
)
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(
descs, remote_agent=
self._remote_agents[engine_id][self.rank * tp_multiplier + i])
+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_side(self._remote_agents[engine_id][self.rank * tp_multiplier + i]
, descs
)
+
+
+ return agent_names
+ return agent_names
+
+
...
@@ -2923,7 +2919,7 @@ index d82d9ad9..931784f8 100644
...
@@ -2923,7 +2919,7 @@ index d82d9ad9..931784f8 100644
def _has_remaining_steps(
def _has_remaining_steps(
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..
6b90ece7
100644
index 3cf1850e..
ae006579
100644
--- a/vllm/engine/multiprocessing/__init__.py
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -14,13 +14,17 @@
from vllm.outputs import RequestOutput
@@ -14,13 +14,17 @@
from vllm.outputs import RequestOutput
...
@@ -2979,7 +2975,7 @@ index 3cf1850e..6b90ece7 100644
...
@@ -2979,7 +2975,7 @@ index 3cf1850e..6b90ece7 100644
class RPCUProfileRequest(Enum):
class RPCUProfileRequest(Enum):
START_PROFILE = 1
START_PROFILE = 1
@@ -157,3 +163,1
0
@@
def ENGINE_DEAD_ERROR(
@@ -157,3 +163,1
3
@@
def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
f"find the original error: {repr(error)}.")
...
@@ -2990,8 +2986,11 @@ index 3cf1850e..6b90ece7 100644
...
@@ -2990,8 +2986,11 @@ index 3cf1850e..6b90ece7 100644
+ request_total_slots: int
+ request_total_slots: int
+ kv_active_blocks: int
+ kv_active_blocks: int
+ kv_total_blocks: int
+ kv_total_blocks: int
+ num_requests_waiting: int
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..
da207947
100644
index 85b5f31e..
fe719642
100644
--- a/vllm/engine/multiprocessing/client.py
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
@@ -8,6 +8,7 @@
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...
@@ -3002,7 +3001,14 @@ index 85b5f31e..da207947 100644
...
@@ -3002,7 +3001,14 @@ index 85b5f31e..da207947 100644
import psutil
import psutil
import zmq
import zmq
import zmq.asyncio
import zmq.asyncio
@@ -25,14 +26,16 @@
from vllm.engine.async_llm_engine import (
@@ -19,20 +20,23 @@
from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.metrics import Stats
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
...
@@ -3022,7 +3028,7 @@ index 85b5f31e..da207947 100644
...
@@ -3022,7 +3028,7 @@ index 85b5f31e..da207947 100644
from vllm.engine.protocol import EngineClient
from vllm.engine.protocol import EngineClient
# yapf: enable
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -46,6 +
49
,8 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -46,6 +
50
,8 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
from vllm.utils import deprecate_kwargs
...
@@ -3031,7 +3037,7 @@ index 85b5f31e..da207947 100644
...
@@ -3031,7 +3037,7 @@ index 85b5f31e..da207947 100644
logger = init_logger(__name__)
logger = init_logger(__name__)
@@ -91,6 +9
6
,7 @@
class MQLLMEngineClient(EngineClient):
@@ -91,6 +9
7
,7 @@
class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
self._errored_with: Optional[BaseException] = None
# Get the configs.
# Get the configs.
...
@@ -3039,7 +3045,7 @@ index 85b5f31e..da207947 100644
...
@@ -3039,7 +3045,7 @@ index 85b5f31e..da207947 100644
self.model_config = engine_config.model_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
self.decoding_config = engine_config.decoding_config
@@ -115,6 +12
1
,10 @@
class MQLLMEngineClient(EngineClient):
@@ -115,6 +12
2
,10 @@
class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
...
@@ -3050,7 +3056,7 @@ index 85b5f31e..da207947 100644
...
@@ -3050,7 +3056,7 @@ index 85b5f31e..da207947 100644
# IPC path for the data socket.
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,8 +1
39
,27 @@
class MQLLMEngineClient(EngineClient):
@@ -129,8 +1
40
,27 @@
class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
self.health_loop: Optional[asyncio.Task] = None
...
@@ -3078,7 +3084,7 @@ index 85b5f31e..da207947 100644
...
@@ -3078,7 +3084,7 @@ index 85b5f31e..da207947 100644
@staticmethod
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
# Pipeline parallel not yet supported
@@ -180,6 +2
09,56
@@
class MQLLMEngineClient(EngineClient):
@@ -180,6 +2
10,63
@@
class MQLLMEngineClient(EngineClient):
except Exception as e:
except Exception as e:
self._set_errored(e)
self._set_errored(e)
...
@@ -3111,15 +3117,22 @@ index 85b5f31e..da207947 100644
...
@@ -3111,15 +3117,22 @@ index 85b5f31e..da207947 100644
+ if await self.metrics_socket.poll(timeout=timeout):
+ if await self.metrics_socket.poll(timeout=timeout):
+ # Metrics received- check the message
+ # Metrics received- check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ message: Frame = await self.metrics_socket.recv(copy=False)
+
kv_
metrics = pickle.loads(message.buffer)
+ metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if self.metrics_publisher is not None:
+ if isinstance(kv_metrics, KvMetrics):
+ if isinstance(metrics, KvMetrics):
+ self.metrics_publisher.publish(kv_metrics.request_active_slots,
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ kv_metrics.request_total_slots,
+ metrics.request_total_slots,
+ kv_metrics.kv_active_blocks,
+ metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+ metrics.kv_total_blocks,
+
+ metrics.num_requests_waiting,
+ logger.debug("Metircs successful.")
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ if isinstance(metrics, Stats):
+ # TODO
+ # Send the whole stats to user
+ pass
+
+ logger.debug("Metrics successful.")
+
+
+ except asyncio.CancelledError:
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
...
@@ -3135,7 +3148,7 @@ index 85b5f31e..da207947 100644
...
@@ -3135,7 +3148,7 @@ index 85b5f31e..da207947 100644
async def run_output_handler_loop(self):
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +35
7
,26 @@
class MQLLMEngineClient(EngineClient):
@@ -278,12 +3
6
5,26 @@
class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
response = await self._wait_for_server_rpc(socket)
...
@@ -3162,7 +3175,7 @@ index 85b5f31e..da207947 100644
...
@@ -3162,7 +3175,7 @@ index 85b5f31e..da207947 100644
def close(self):
def close(self):
"""Destroy the ZeroMQ Context."""
"""Destroy the ZeroMQ Context."""
@@ -293,6 +3
86
,8 @@
class MQLLMEngineClient(EngineClient):
@@ -293,6 +3
94
,8 @@
class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
# Cancel background tasks.
if self.health_loop is not None:
if self.health_loop is not None:
self.health_loop.cancel()
self.health_loop.cancel()
...
@@ -3171,7 +3184,7 @@ index 85b5f31e..da207947 100644
...
@@ -3171,7 +3184,7 @@ index 85b5f31e..da207947 100644
if self.output_loop is not None:
if self.output_loop is not None:
self.output_loop.cancel()
self.output_loop.cancel()
@@ -415,6 +51
0
,9 @@
class MQLLMEngineClient(EngineClient):
@@ -415,6 +51
8
,9 @@
class MQLLMEngineClient(EngineClient):
"""
"""
if self._errored_with is not None:
if self._errored_with is not None:
raise self._errored_with
raise self._errored_with
...
@@ -3181,7 +3194,7 @@ index 85b5f31e..da207947 100644
...
@@ -3181,7 +3194,7 @@ index 85b5f31e..da207947 100644
@property
@property
def is_running(self) -> bool:
def is_running(self) -> bool:
@@ -473,6 +57
1
,7 @@
class MQLLMEngineClient(EngineClient):
@@ -473,6 +57
9
,7 @@
class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
priority: int = 0,
...
@@ -3189,7 +3202,7 @@ index 85b5f31e..da207947 100644
...
@@ -3189,7 +3202,7 @@ index 85b5f31e..da207947 100644
*,
*,
inputs: Optional[PromptType] = None # DEPRECATED
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +60
1
,8 @@
class MQLLMEngineClient(EngineClient):
@@ -502,7 +60
9
,8 @@
class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
lora_request, trace_headers,
...
@@ -3199,7 +3212,7 @@ index 85b5f31e..da207947 100644
...
@@ -3199,7 +3212,7 @@ index 85b5f31e..da207947 100644
@overload
@overload
def encode(
def encode(
@@ -586,6 +6
86
,7 @@
class MQLLMEngineClient(EngineClient):
@@ -586,6 +6
94
,7 @@
class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
priority: int = 0,
...
@@ -3207,7 +3220,7 @@ index 85b5f31e..da207947 100644
...
@@ -3207,7 +3220,7 @@ index 85b5f31e..da207947 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +73
1
,12 @@
class MQLLMEngineClient(EngineClient):
@@ -630,6 +73
9
,12 @@
class MQLLMEngineClient(EngineClient):
else:
else:
lp_bytes = None
lp_bytes = None
...
@@ -3220,7 +3233,7 @@ index 85b5f31e..da207947 100644
...
@@ -3220,7 +3233,7 @@ index 85b5f31e..da207947 100644
request_bytes = pickle.dumps(
request_bytes = pickle.dumps(
RPCProcessRequest(
RPCProcessRequest(
prompt=prompt,
prompt=prompt,
@@ -639,11 +74
6
,11 @@
class MQLLMEngineClient(EngineClient):
@@ -639,11 +7
5
4,11 @@
class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
priority=priority,
...
@@ -3234,7 +3247,7 @@ index 85b5f31e..da207947 100644
...
@@ -3234,7 +3247,7 @@ index 85b5f31e..da207947 100644
await self.input_socket.send_multipart(parts, copy=False)
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
# 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +8
1
2,6 @@
class MQLLMEngineClient(EngineClient):
@@ -705,3 +82
0
,6 @@
class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
if isinstance(request_output, BaseException):
raise request_output
raise request_output
...
@@ -3242,10 +3255,10 @@ index 85b5f31e..da207947 100644
...
@@ -3242,10 +3255,10 @@ index 85b5f31e..da207947 100644
+ def set_metrics_publisher(self, metrics_publisher):
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..
dbd9d58d
100644
index a0dd7958..
3204cfb8
100644
--- a/vllm/engine/multiprocessing/engine.py
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,
73
@@
@@ -3,35 +3,
115
@@
import pickle
import pickle
import signal
import signal
from contextlib import contextmanager
from contextlib import contextmanager
...
@@ -3305,26 +3318,68 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3305,26 +3318,68 @@ index a0dd7958..dbd9d58d 100644
+ self.metrics_socket = metrics_socket
+ self.metrics_socket = metrics_socket
+
+
+ # KV metrics
+ # KV metrics
+ self._send_kv_metrics(0, 0)
+ self._send_kv_metrics(0,
0, 0, 0.0, 0.
0)
+
+
+ def log(self, stats: Stats) -> None:
+ def log(self, stats: Stats) -> None:
+ self._send_kv_metrics(
+ self._send_kv_metrics(
+ stats.num_running_sys,
+ stats.num_running_sys,
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks)
+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks),
+ stats.num_waiting_sys,
+ stats.gpu_cache_usage_sys,
+ stats.gpu_prefix_cache_hit_rate
+ )
+ )
+
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+ pass
+
+
+ def _send_kv_metrics(self, active_slots, active_kv_blocks):
+ def _send_kv_metrics(
+ self,
+ active_slots,
+ active_kv_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ ):
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(
+ KvMetrics(
+ active_slots,
+ self.request_total_slots,
+ active_kv_blocks,
+ self.kv_total_blocks,
+ num_requests_waiting,
+ gpu_cache_usage_perc,
+ gpu_prefix_cache_hit_rate,
+ )
+ )
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase):
+ def __init__(
+ self,
+ metrics_socket
+ ):
+ self.metrics_socket = metrics_socket
+
+ def log(self, stats: Stats) -> None:
+ self._send_metrics(stats)
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
+ def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed:
+ if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(
KvMetrics(active_slots, self.request_total_slots, active_kv_blocks, self.kv_total_blocks)
)
+ metrics_bytes = pickle.dumps(
stats
)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+
+
+
class MQLLMEngine:
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +1
32
,3
1
@@
class MQLLMEngine:
@@ -94,12 +1
74
,3
5
@@
class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
...
@@ -3346,17 +3401,21 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3346,17 +3401,21 @@ index a0dd7958..dbd9d58d 100644
+
+
+
+
+ # Attach logger for continuous metrics publishing
+ # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger(
+ self.
kv_
stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs,
+ self.engine.scheduler_config.max_num_seqs,
+ self.engine.cache_config.num_gpu_blocks,
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ self.metrics_socket
+ )
+ )
+ self.engine.add_logger("kv_metrics", self.stat_logger)
+ self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger)
+
+
@property
@property
def dead_error(self) -> BaseException:
def dead_error(self) -> BaseException:
if self._errored_with is not None:
if self._errored_with is not None:
@@ -171,8 +2
28
,17 @@
class MQLLMEngine:
@@ -171,8 +2
74
,17 @@
class MQLLMEngine:
# Handle the query from the Client.
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
tracing_enabled = self.engine.is_tracing_enabled()
...
@@ -3376,7 +3435,7 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3376,7 +3435,7 @@ index a0dd7958..dbd9d58d 100644
except Exception as e:
except Exception as e:
response = e
response = e
@@ -185,6 +2
51
,7 @@
class MQLLMEngine:
@@ -185,6 +2
97
,7 @@
class MQLLMEngine:
while True:
while True:
if not self.engine.has_unfinished_requests():
if not self.engine.has_unfinished_requests():
...
@@ -3384,7 +3443,7 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3384,7 +3443,7 @@ index a0dd7958..dbd9d58d 100644
# Poll until there is work to do.
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
# When there's no work, check on engine health and send
@@ -220,6 +
287
,13 @@
class MQLLMEngine:
@@ -220,6 +
333
,13 @@
class MQLLMEngine:
def handle_new_input(self):
def handle_new_input(self):
"""Handle new input from the socket"""
"""Handle new input from the socket"""
try:
try:
...
@@ -3398,7 +3457,7 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3398,7 +3457,7 @@ index a0dd7958..dbd9d58d 100644
while self.input_socket.poll(timeout=0) != 0:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
request = pickle.loads(frames[0].buffer)
@@ -262,6 +3
36
,11 @@
class MQLLMEngine:
@@ -262,6 +3
82
,11 @@
class MQLLMEngine:
self._send_outputs(rpc_err)
self._send_outputs(rpc_err)
try:
try:
...
@@ -3410,7 +3469,7 @@ index a0dd7958..dbd9d58d 100644
...
@@ -3410,7 +3469,7 @@ index a0dd7958..dbd9d58d 100644
self.engine.add_request(
self.engine.add_request(
request_id=request_id,
request_id=request_id,
prompt=request.prompt,
prompt=request.prompt,
@@ -269,7 +34
8
,9 @@
class MQLLMEngine:
@@ -269,7 +3
9
4,9 @@
class MQLLMEngine:
lora_request=request.lora_request,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
prompt_adapter_request=request.prompt_adapter_request,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment