Unverified Commit 6f7de33b authored by Mark McLoughlin's avatar Mark McLoughlin Committed by GitHub
Browse files

[Metrics] Refactor LoRA state tracking (#26801)


Signed-off-by: Mark McLoughlin <markmc@redhat.com>
parent a98cc35c
...@@ -15,12 +15,19 @@ from tests.v1.engine.utils import ( ...@@ -15,12 +15,19 @@ from tests.v1.engine.utils import (
) )
from vllm import PoolingParams from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import (
EngineCoreEvent,
EngineCoreEventType,
EngineCoreOutputs,
EngineCoreRequest,
FinishReason,
)
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
def _ref_convert_id_to_token( def _ref_convert_id_to_token(
...@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active assert iteration_stats.num_generation_tokens == num_active
@pytest.mark.parametrize("log_stats", [True, False])
def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
    """Check per-LoRA request accounting across a request's lifecycle.

    Three requests are submitted (two with distinct LoRA adapters, one with
    none) and stepped through QUEUED -> SCHEDULED -> finished. When
    ``log_stats`` is enabled, the waiting/running counts written onto
    ``SchedulerStats`` and the processor's ``lora_states.requests`` mapping
    are verified after each step. When it is disabled, the mapping must stay
    empty throughout.
    """
    processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=log_stats)
    mock_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    timestamp = time.monotonic()

    # Adapter per request index:
    #   request-0 -> lora-1, request-1 -> lora-2, request-2 -> no adapter.
    adapters = [
        LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1"),
        LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2"),
        None,
    ]

    # Build every request first, then register them with the processor.
    pending = []
    for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens):
        pending.append(
            EngineCoreRequest(
                request_id=f"request-{idx}",
                prompt_token_ids=prompt_tokens,
                mm_features=None,
                eos_token_id=None,
                arrival_time=0,
                lora_request=adapters[idx],
                cache_salt=None,
                data_parallel_rank=None,
                sampling_params=SamplingParams(),
                pooling_params=None,
            )
        )
    for req in pending:
        processor.add_request(req, None)

    def drive(event_type=None, finish_request_id=None):
        """Run one engine step and feed it through the output processor.

        If ``event_type`` is given, every output in the step is tagged with a
        single event of that type. If ``finish_request_id`` is given, the
        first output with that request id is marked finished (LENGTH).
        Returns the step's ``SchedulerStats`` and the ``IterationStats``
        passed to the processor (``None`` when stats logging is off).
        """
        step = EngineCoreOutputs(
            outputs=mock_core.get_outputs(), scheduler_stats=SchedulerStats()
        )
        if event_type is not None:
            for out in step.outputs:
                out.events = [EngineCoreEvent.new_event(event_type, timestamp)]
        if finish_request_id is not None:
            for out in step.outputs:
                if out.request_id == finish_request_id:
                    out.finish_reason = FinishReason.LENGTH
                    break
        stats = IterationStats() if log_stats else None
        processor.process_outputs(step.outputs, timestamp, stats)
        processor.update_scheduler_stats(step.scheduler_stats)
        return step.scheduler_stats, stats

    # Step 1: all requests report QUEUED.
    sched_stats, iteration_stats = drive(event_type=EngineCoreEventType.QUEUED)
    if not log_stats:
        # With stats logging disabled nothing may be tracked.
        assert iteration_stats is None
        assert len(processor.lora_states.requests) == 0
    else:
        # Both adapters have one waiting request and none running.
        assert sched_stats.waiting_lora_adapters.get("lora-1") == 1
        assert sched_stats.waiting_lora_adapters.get("lora-2") == 1
        assert sched_stats.running_lora_adapters.get("lora-1") == 0
        assert sched_stats.running_lora_adapters.get("lora-2") == 0
        # Internal state holds exactly the two adapters.
        assert len(processor.lora_states.requests) == 2
        assert "lora-1" in processor.lora_states.requests
        assert "lora-2" in processor.lora_states.requests

    # Step 2: all requests report SCHEDULED.
    sched_stats, iteration_stats = drive(event_type=EngineCoreEventType.SCHEDULED)
    if not log_stats:
        assert iteration_stats is None
        assert len(processor.lora_states.requests) == 0
    else:
        # Requests moved from waiting to running.
        assert sched_stats.waiting_lora_adapters.get("lora-1") == 0
        assert sched_stats.waiting_lora_adapters.get("lora-2") == 0
        assert sched_stats.running_lora_adapters.get("lora-1") == 1
        assert sched_stats.running_lora_adapters.get("lora-2") == 1

    # Step 3: request-0 (lora-1) finishes.
    sched_stats, iteration_stats = drive(finish_request_id="request-0")
    if not log_stats:
        assert len(processor.lora_states.requests) == 0
    else:
        # lora-1 has no requests left, so its entry is gone.
        assert "lora-1" not in processor.lora_states.requests
        # lora-2 is still running.
        assert sched_stats.running_lora_adapters.get("lora-2") == 1
        assert len(processor.lora_states.requests) == 1

    # Step 4: request-1 (lora-2) finishes.
    sched_stats, iteration_stats = drive(finish_request_id="request-1")
    if not log_stats:
        assert len(processor.lora_states.requests) == 0
    else:
        # lora-2 has no requests left, so its entry is gone too.
        assert "lora-2" not in processor.lora_states.requests
        assert len(sched_stats.running_lora_adapters) == 0
        assert len(processor.lora_states.requests) == 0

    # Step 5: request-2 (no adapter) finishes.
    drive(finish_request_id="request-2")

    # Every request has now completed.
    assert processor.get_num_unfinished_requests() == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_output_collector(): async def test_request_output_collector():
NUM_REQS = 3 NUM_REQS = 3
......
...@@ -5,20 +5,4 @@ from vllm.v1.metrics.stats import IterationStats ...@@ -5,20 +5,4 @@ from vllm.v1.metrics.stats import IterationStats
def test_iteration_stats_repr(): def test_iteration_stats_repr():
iteration_stats = IterationStats() iteration_stats = IterationStats()
iteration_stats.iteration_timestamp = 0 assert repr(iteration_stats).startswith("IterationStats(")
expected_repr = (
"IterationStats("
"iteration_timestamp=0, "
"num_generation_tokens=0, "
"num_prompt_tokens=0, "
"num_preempted_reqs=0, "
"finished_requests=[], "
"max_num_generation_tokens_iter=[], "
"n_params_iter=[], "
"time_to_first_tokens_iter=[], "
"inter_token_latencies_iter=[], "
"waiting_lora_adapters={}, "
"running_lora_adapters={}, "
"num_corrupted_reqs=0)"
)
assert repr(iteration_stats) == expected_repr
...@@ -508,6 +508,8 @@ class AsyncLLM(EngineClient): ...@@ -508,6 +508,8 @@ class AsyncLLM(EngineClient):
processed_outputs.reqs_to_abort processed_outputs.reqs_to_abort
) )
output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
......
...@@ -289,6 +289,7 @@ class LLMEngine: ...@@ -289,6 +289,7 @@ class LLMEngine:
engine_core_timestamp=outputs.timestamp, engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
) )
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings. # 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
......
...@@ -22,7 +22,12 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason ...@@ -22,7 +22,12 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats from vllm.v1.metrics.stats import (
IterationStats,
LoRARequestStates,
RequestStateStats,
SchedulerStats,
)
class RequestOutputCollector: class RequestOutputCollector:
...@@ -310,7 +315,7 @@ class OutputProcessor: ...@@ -310,7 +315,7 @@ class OutputProcessor:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {} self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {} self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates() self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None self.tracer: Tracer | None = None
def get_num_unfinished_requests(self): def get_num_unfinished_requests(self):
...@@ -334,7 +339,7 @@ class OutputProcessor: ...@@ -334,7 +339,7 @@ class OutputProcessor:
for request_id in request_ids: for request_id in request_ids:
req_state = self.request_states.pop(request_id, None) req_state = self.request_states.pop(request_id, None)
if req_state is not None: if req_state is not None:
self.lora_states.abort_request(req_state) self.lora_states.request_finished(request_id, req_state.lora_name)
request_ids_to_abort.append(request_id) request_ids_to_abort.append(request_id)
# Produce final abort output. # Produce final abort output.
if req_state.queue is not None and ( if req_state.queue is not None and (
...@@ -382,7 +387,6 @@ class OutputProcessor: ...@@ -382,7 +387,6 @@ class OutputProcessor:
log_stats=self.log_stats, log_stats=self.log_stats,
) )
self.request_states[request_id] = req_state self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req: if parent_req:
self.parent_requests[parent_req.request_id] = parent_req self.parent_requests[parent_req.request_id] = parent_req
...@@ -484,13 +488,15 @@ class OutputProcessor: ...@@ -484,13 +488,15 @@ class OutputProcessor:
) )
if self.tracer: if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats) self.do_tracing(engine_core_output, req_state, iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)
return OutputProcessorOutput( return OutputProcessorOutput(
request_outputs=request_outputs, request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort, reqs_to_abort=reqs_to_abort,
) )
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
    """Pass ``scheduler_stats`` to the LoRA state tracker.

    Delegates to ``self.lora_states.update_scheduler_stats``; the argument
    (which may be ``None``) is forwarded unchanged.
    """
    lora_tracker = self.lora_states
    lora_tracker.update_scheduler_stats(scheduler_stats)
def do_tracing( def do_tracing(
self, self,
engine_core_output: EngineCoreOutput, engine_core_output: EngineCoreOutput,
...@@ -564,8 +570,6 @@ class OutputProcessor: ...@@ -564,8 +570,6 @@ class OutputProcessor:
if iteration_stats is None: if iteration_stats is None:
return return
lora_stats = self.lora_states.get_stats(req_state)
assert engine_core_timestamp is not None assert engine_core_timestamp is not None
assert req_state.stats is not None assert req_state.stats is not None
iteration_stats.update_from_output( iteration_stats.update_from_output(
...@@ -574,7 +578,8 @@ class OutputProcessor: ...@@ -574,7 +578,8 @@ class OutputProcessor:
req_state.is_prefilling, req_state.is_prefilling,
req_state.prompt_len, req_state.prompt_len,
req_state.stats, req_state.stats,
lora_stats, self.lora_states,
req_state.lora_name,
) )
def _update_stats_from_finished( def _update_stats_from_finished(
...@@ -596,7 +601,7 @@ class OutputProcessor: ...@@ -596,7 +601,7 @@ class OutputProcessor:
max_tokens_param=req_state.max_tokens_param, max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats, req_stats=req_state.stats,
) )
self.lora_states.finish_request(req_state) self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
ParentRequest.observe_finished_request( ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
......
...@@ -989,6 +989,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -989,6 +989,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
scheduler_stats.kv_connector_stats, engine_idx scheduler_stats.kv_connector_stats, engine_idx
) )
if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
scheduler_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
scheduler_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
if mm_cache_stats is not None: if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
...@@ -1055,20 +1069,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ...@@ -1055,20 +1069,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
finished_request.max_tokens_param finished_request.max_tokens_param
) )
if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
iteration_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
iteration_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
def record_sleep_state(self, sleep: int = 0, level: int = 0): def record_sleep_state(self, sleep: int = 0, level: int = 0):
awake = 1 awake = 1
discard_all = 0 discard_all = 0
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections import deque from collections import defaultdict, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
...@@ -11,7 +11,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats ...@@ -11,7 +11,6 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
from vllm.v1.engine.output_processor import RequestState
@dataclass @dataclass
...@@ -170,11 +169,8 @@ class SchedulerStats: ...@@ -170,11 +169,8 @@ class SchedulerStats:
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None kv_connector_stats: dict[str, Any] | None = None
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
@dataclass running_lora_adapters: dict[str, int] = field(default_factory=dict)
class LoRAStats:
waiting_requests: set[str] = field(default_factory=set)
running_requests: set[str] = field(default_factory=set)
@dataclass @dataclass
...@@ -229,8 +225,6 @@ class IterationStats: ...@@ -229,8 +225,6 @@ class IterationStats:
self.n_params_iter: list[int] = [] self.n_params_iter: list[int] = []
self.time_to_first_tokens_iter: list[float] = [] self.time_to_first_tokens_iter: list[float] = []
self.inter_token_latencies_iter: list[float] = [] self.inter_token_latencies_iter: list[float] = []
self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: dict[str, int] = {}
self.num_corrupted_reqs: int = 0 self.num_corrupted_reqs: int = 0
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -248,7 +242,8 @@ class IterationStats: ...@@ -248,7 +242,8 @@ class IterationStats:
is_prefilling: bool, is_prefilling: bool,
prompt_len: int, prompt_len: int,
req_stats: RequestStateStats, req_stats: RequestStateStats,
lora_stats: LoRAStats | None, lora_states: "LoRARequestStates",
lora_name: str | None,
): ):
num_new_generation_tokens = len(output.new_token_ids) num_new_generation_tokens = len(output.new_token_ids)
...@@ -274,7 +269,12 @@ class IterationStats: ...@@ -274,7 +269,12 @@ class IterationStats:
# Process request-level engine core events # Process request-level engine core events
if output.events is not None: if output.events is not None:
self.update_from_events( self.update_from_events(
output.request_id, output.events, is_prefilling, req_stats, lora_stats output.request_id,
output.events,
is_prefilling,
req_stats,
lora_states,
lora_name,
) )
# Process the batch-level "new tokens" engine core event # Process the batch-level "new tokens" engine core event
...@@ -292,7 +292,8 @@ class IterationStats: ...@@ -292,7 +292,8 @@ class IterationStats:
events: list["EngineCoreEvent"], events: list["EngineCoreEvent"],
is_prefilling: bool, is_prefilling: bool,
req_stats: RequestStateStats, req_stats: RequestStateStats,
lora_stats: LoRAStats | None, lora_states: "LoRARequestStates",
lora_name: str | None,
): ):
# Avoid circular dependency # Avoid circular dependency
from vllm.v1.engine import EngineCoreEventType from vllm.v1.engine import EngineCoreEventType
...@@ -300,15 +301,14 @@ class IterationStats: ...@@ -300,15 +301,14 @@ class IterationStats:
for event in events: for event in events:
if event.type == EngineCoreEventType.QUEUED: if event.type == EngineCoreEventType.QUEUED:
req_stats.queued_ts = event.timestamp req_stats.queued_ts = event.timestamp
if lora_stats is not None: lora_states.request_waiting(req_id, lora_name)
lora_stats.waiting_requests.add(req_id)
elif event.type == EngineCoreEventType.SCHEDULED: elif event.type == EngineCoreEventType.SCHEDULED:
if req_stats.scheduled_ts == 0.0: # ignore preemptions if req_stats.scheduled_ts == 0.0: # ignore preemptions
req_stats.scheduled_ts = event.timestamp req_stats.scheduled_ts = event.timestamp
LoRARequestStates.scheduled_request(lora_stats, req_id) lora_states.request_running(req_id, lora_name)
elif event.type == EngineCoreEventType.PREEMPTED: elif event.type == EngineCoreEventType.PREEMPTED:
self.num_preempted_reqs += 1 self.num_preempted_reqs += 1
LoRARequestStates.preempted_request(lora_stats, req_id) lora_states.request_waiting(req_id, lora_name)
def update_from_finished_request( def update_from_finished_request(
self, self,
...@@ -361,61 +361,60 @@ class IterationStats: ...@@ -361,61 +361,60 @@ class IterationStats:
self.num_corrupted_reqs += 1 self.num_corrupted_reqs += 1
class LoRARequestStates: class LoRAStats:
"""Per-LoRA request state stats.""" """Tracks waiting and running request IDs for a single LoRA."""
def __init__(self): def __init__(self):
self.lora_name_to_stats: dict[str, LoRAStats] = {} self.waiting: set[str] = set()
self.running: set[str] = set()
def get_stats(self, req_state: "RequestState") -> LoRAStats | None: def update(self, req_id: str, waiting: bool, running: bool):
if req_state.lora_name is None: assert not (waiting and running)
return None if waiting:
if req_state.lora_name not in self.lora_name_to_stats: self.waiting.add(req_id)
self.lora_name_to_stats[req_state.lora_name] = LoRAStats() else:
return self.lora_name_to_stats[req_state.lora_name] self.waiting.discard(req_id)
def add_request(self, req_state: "RequestState"): if running:
if (lora_stats := self.get_stats(req_state)) is not None: self.running.add(req_id)
lora_stats.waiting_requests.add(req_state.request_id) else:
self.running.discard(req_id)
def finish_request(self, req_state: "RequestState"): @property
if req_state.lora_name is None: def empty(self) -> bool:
return return not (self.waiting or self.running)
lora_stats = self.lora_name_to_stats[req_state.lora_name]
lora_stats.running_requests.remove(req_state.request_id)
def abort_request(self, req_state: "RequestState"):
if req_state.lora_name is None:
return
lora_stats = self.lora_name_to_stats[req_state.lora_name]
lora_stats.waiting_requests.discard(req_state.request_id)
lora_stats.running_requests.discard(req_state.request_id)
# Break the pattern for this lifecycle methods so we can
# call this from IterationStats.update_from_events()
@staticmethod
def scheduled_request(lora_stats: LoRAStats | None, request_id: str):
if lora_stats is None:
return
lora_stats.waiting_requests.remove(request_id)
lora_stats.running_requests.add(request_id)
@staticmethod class LoRARequestStates:
def preempted_request(lora_stats: LoRAStats | None, request_id: str): """A per-LoRA count of running and waiting requests."""
if lora_stats is None:
def __init__(self, log_stats: bool = False):
self.log_stats = log_stats
self.requests: defaultdict[str, LoRAStats] = defaultdict(LoRAStats)
def _request_update(
self, req_id: str, lora_name: str | None, waiting: bool, running: bool
):
if not self.log_stats or lora_name is None:
return return
lora_stats.running_requests.remove(request_id)
lora_stats.waiting_requests.add(request_id)
def update_iteration_stats(self, iteration_stats: IterationStats | None): lora_stats = self.requests[lora_name]
if iteration_stats is None: lora_stats.update(req_id, waiting, running)
if lora_stats.empty:
del self.requests[lora_name]
def request_waiting(self, req_id: str, lora_name: str | None):
self._request_update(req_id, lora_name, waiting=True, running=False)
def request_running(self, req_id: str, lora_name: str | None):
self._request_update(req_id, lora_name, waiting=False, running=True)
def request_finished(self, req_id: str, lora_name: str | None):
self._request_update(req_id, lora_name, waiting=False, running=False)
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
if not self.log_stats or scheduler_stats is None:
return return
for lora_name, stats in self.lora_name_to_stats.items(): for lora_name, stats in self.requests.items():
if stats.waiting_requests: scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
iteration_stats.waiting_lora_adapters[lora_name] = len( scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)
stats.waiting_requests
)
if stats.running_requests:
iteration_stats.running_lora_adapters[lora_name] = len(
stats.running_requests
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment