async_llm_engine.py 50.3 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.engine.protocol import EngineClient
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.gpu_executor import GPUExecutorAsync
20
from vllm.executor.ray_utils import initialize_ray_cluster
21
from vllm.inputs import PromptType
22
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.logger import init_logger
24
from vllm.lora.request import LoRARequest
25
26
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
27
from vllm.model_executor.layers.sampler import SamplerOutput
28
from vllm.outputs import PoolingRequestOutput, RequestOutput
29
from vllm.pooling_params import PoolingParams
30
from vllm.prompt_adapter.request import PromptAdapterRequest
31
from vllm.sampling_params import SamplingParams
32
from vllm.sequence import ExecuteModelRequest
33
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
34
from vllm.usage.usage_lib import UsageContext
35
from vllm.utils import deprecate_kwargs, weak_bind
36
37

# Module-wide logger for the async engine.
logger = init_logger(__name__)

# Captured from ``vllm.envs`` once, when this module is imported
# (presumably in seconds, per the ``_S`` suffix — confirm in vllm.envs).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
39

Antoni Baum's avatar
Antoni Baum committed
40

41
42
43
44
class AsyncEngineDeadError(RuntimeError):
    """Signals that the async engine's background task has stopped
    unexpectedly and the engine can no longer serve requests."""


45
46
47
48
49
50
51
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Inspect the outcome of the finished ``engine.run_engine_loop()`` task.

    That task runs a ``while True`` loop that can only exit if there is an
    exception, so every outcome other than cancellation is a failure:

    * cancelled: logged as a graceful shutdown; nothing is raised.
    * returned normally, or raised an ``Exception``: the failure is logged,
      ``error_callback`` is called with it, and :class:`AsyncEngineDeadError`
      is raised chained to it.

    Args:
        task: The engine background task; only intended for
            ``engine.run_engine_loop()``.
        error_callback: Invoked with the failure before re-raising.

    Raises:
        AsyncEngineDeadError: If the task did not end by cancellation.
    """
    try:
        returned = task.result()
        # A normal return is itself a failure. Raising here routes it through
        # the generic ``except Exception`` handler below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {returned}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is assumed to mean a graceful shutdown; this should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from err
71
72


73
74
75
# Sentinel queued by ``AsyncStream.finish()`` to mark a normal end of stream.
# It is a single shared ``Exception`` instance so that the stream's
# "is this raisable?" check routes it to the end-of-stream branch.
STOP_ITERATION = Exception()


Antoni Baum's avatar
Antoni Baum committed
76
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers push items with :meth:`put` and end the stream with
    :meth:`finish`. The consumer drains it through :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with ``request_id`` if the consumer closes the generator
        # before the stream finishes.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream as finished. Calls after the first are no-ops.

        If ``exception`` is an exception instance or class, the consumer
        raises it after draining the items queued before it. Otherwise the
        consumer's generator returns normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: Whatever exception was passed to :meth:`finish`.
            asyncio.CancelledError: If the consumer closes the generator
                early; the request is cancelled via the ``cancel`` callback.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a sentinel instance: compare by
                    # identity, not equality.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer stopped iterating early; cancel the request.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """True if ``value`` is an exception instance or exception class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
125

126
127
128
129
130
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers queue new requests and aborts here. The background loop drains
    both in one batch via :meth:`get_new_and_aborted_requests`.
    """

    def __init__(self) -> None:
        # Streams of requests already handed to the engine, keyed by ID.
        self._request_streams: Dict[str, AsyncStream] = {}
        # IDs queued by add_request() but not yet drained by
        # get_new_and_aborted_requests(). Tracked so that a duplicate ID
        # submitted before the next loop iteration is rejected as well.
        self._pending_request_ids: Set[str] = set()
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Forwards the output to its stream and finishes the stream when the
        output is marked ``finished``.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If ``request_id`` is already tracked, or is already
                queued for the next loop iteration.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # A still-queued request with this ID is cancelled when the queue is
        # drained, so the ID may be submitted again right away.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        A request that was both queued and aborted since the last call is
        cancelled locally and omitted from both results.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            self._pending_request_ids.discard(request_id)
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()
258

Antoni Baum's avatar
Antoni Baum committed
259
260
261
262

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index of the virtual engine (one pipeline-parallel
                microbatch slot) to step.

        Returns:
            The request outputs collected in this virtual engine's scheduler
            context during this step.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = (
                False if not seq_group_metadata_list else
                seq_group_metadata_list[0].state.num_steps == 1)

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base one)."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Validates the request, preprocesses the prompt asynchronously and,
        for guided decoding, builds the logits processor through the async
        builder. The processed request is then registered via
        ``_add_processed_request``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or a non-zero ``priority`` is given without priority
                scheduling.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        # Looked up at most once per request and reused for guided decoding
        # below, instead of awaiting the tokenizer group a second time.
        tokenizer: Optional[AnyTokenizer] = None
        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            if tokenizer is None:
                tokenizer = await self.get_tokenizer_async(lora_request)
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=tokenizer,
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer (if any) and the model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
524

525

526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Convert ``sampling_params.guided_decoding`` into a logits processor.

    When ``guided_decoding`` is set:

    * its ``backend`` falls back to ``default_guided_backend`` if falsy
      (this mutates the ``guided_decoding`` object itself);
    * a processor is built for it and, if one is returned, appended to
      ``sampling_params.logits_processors`` (created when ``None``);
    * ``sampling_params.guided_decoding`` is then reset to ``None``.

    Only ``guided_decoding`` is consumed; other sampling fields are left
    untouched.

    Args:
        sampling_params: Sampling parameters, modified in place.
        tokenizer: Tokenizer passed to the processor factory.
        default_guided_backend: Backend used when none is set on the request.

    Returns:
        The same ``sampling_params`` object. It is returned unchanged when
        ``guided_decoding`` is ``None``.
    """
    if (guided_decoding := sampling_params.guided_decoding) is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


556
class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class wraps an :class:`LLMEngine` (instantiated from
    ``_engine_class``) so it can be driven from asyncio code. Requests
    submitted through :meth:`generate` / :meth:`encode` are queued in a
    request tracker and fed to the engine by a background asyncio task that
    repeatedly calls :meth:`engine_step`. Each request's outputs are streamed
    back to the caller through an async generator.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            is started automatically on the first request (e.g. a
            ``generate`` call) if it is not already running.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    # Engine implementation instantiated by ``__init__``. It is looked up
    # through ``self``, so a subclass can substitute a different class.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args,
                 log_requests: bool = True,
578
                 start_engine_loop: bool = True,
579
                 **kwargs) -> None:
580
        self.log_requests = log_requests
581
        self.engine = self._engine_class(*args, **kwargs)
Antoni Baum's avatar
Antoni Baum committed
582

583
584
585
        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
586
587
588
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

589
590
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
591
                weak_bind(self.process_request_outputs)
592

593
        self.background_loop: Optional[asyncio.Future] = None
594
595
596
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
597
        self._background_loop_unshielded: Optional[asyncio.Task] = None
598
        self.start_engine_loop = start_engine_loop
599
        self._errored_with: Optional[BaseException] = None
Antoni Baum's avatar
Antoni Baum committed
600

601
602
603
        # Lazy initialized fields
        self._request_tracker: RequestTracker

604
605
606
607
608
    def __del__(self):
        if rt := getattr(self, "request_tracker", None):
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

609
    @classmethod
    def _get_executor_cls(
            cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        A class passed directly as ``distributed_executor_backend`` wins and
        must subclass :class:`ExecutorAsyncBase`. Otherwise the choice is
        made from ``device_config.device_type`` and, where relevant, the
        backend name (``"ray"``, ``"mp"`` or ``None``). Executor modules are
        imported inside each branch, so only the selected one is imported.
        For HPU with Ray, the Ray cluster is also initialised here.

        Raises:
            TypeError: if a class is given that is not an ExecutorAsyncBase.
            AssertionError: for a distributed backend on TPU (other than
                Ray) or OpenVINO.
            RuntimeError: for an unsupported distributed backend on XPU.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
                executor_class = RayHPUExecutorAsync
            else:
                from vllm.executor.hpu_executor import HPUExecutorAsync
                executor_class = HPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[VllmConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Args:
            engine_args: Parsed async engine arguments.
            engine_config: Pre-built config; derived from ``engine_args``
                when omitted.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Usage context used to build the config and
                passed to the engine.
            stat_loggers: Optional named stat loggers for the engine.

        Returns:
            A new engine built with the executor chosen by
            :meth:`_get_executor_cls`. If that executor uses Ray, the Ray
            cluster is initialised first.
        """
        # Create the engine configs.
        if engine_config is None:
            engine_config = engine_args.create_engine_config(usage_context)

        executor_class = cls._get_executor_cls(engine_config)

        if executor_class.uses_ray:
            initialize_ray_cluster(engine_config.parallel_config)

        # Create the async LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

    @property
    def is_running(self) -> bool:
707
        return (self.background_loop is not None
708
                and self._background_loop_unshielded is not None
709
710
711
712
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
713
714
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
715
716
717
718
719
720
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

721
    @property
    def dead_error(self) -> BaseException:
        """A new ``AsyncEngineDeadError`` describing a stopped background loop.

        The exception is only constructed here; callers are expected to
        ``raise`` it.
        """
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
735

736
737
738
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine's tokenizer, resolved for ``lora_request``.

        Delegates to ``self.engine.get_tokenizer_async``.
        """
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh request tracker and schedules :meth:`run_engine_loop`
        as a task on the current event loop. The task receives only a weak
        reference to ``self``. ``_log_task_completion`` is attached as a
        done-callback that reports failures to :meth:`_error_callback`, and
        ``background_loop`` holds an ``asyncio.shield`` of the task.

        Raises:
            AsyncEngineDeadError: if the engine has already errored.
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

775
    async def engine_step(self, virtual_engine: int) -> bool:
776
777
778
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""
779

780
781
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())
782
783
784

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
785
            try:
786
                await self.engine.add_request_async(**new_request)
787
788
789
790
791
792
793
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )
794

795
796
        if aborted_requests:
            await self._engine_abort(aborted_requests)
797

798
        request_outputs = await self.engine.step_async(virtual_engine)
799

Antoni Baum's avatar
Antoni Baum committed
800
        # Put the outputs into the corresponding streams.
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        # Put the outputs into the corresponding streams.
        all_finished = True
816
        for request_output in request_outputs:
817
            self._request_tracker.process_request_output(
818
                request_output, verbose=self.log_requests)
819
            all_finished = all_finished and request_output.finished
Antoni Baum's avatar
Antoni Baum committed
820

821
        return all_finished
822

Antoni Baum's avatar
Antoni Baum committed
823
    async def _engine_abort(self, request_ids: Iterable[str]):
824
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
825

826
827
828
829
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        The loop returns as soon as the referent is gone. While idle, it
        stops the remote worker execution loop and parks on the request
        tracker until new requests arrive. It then runs one ``engine_step``
        task per pipeline-parallel virtual engine, rescheduling a virtual
        engine's step while it reports in-progress work or still has
        unfinished requests.

        If an iteration exceeds ``ENGINE_ITERATION_TIMEOUT_S``, the engine is
        marked errored and the ``asyncio.TimeoutError`` is re-raised.
        Exceptions raised by an ``engine_step`` task propagate via
        ``task.result()``.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # ``add_request`` does not need to be async, but it is kept that way
    # for backwards compatibility.
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        # Typing-only signature for the deprecated ``inputs=`` keyword form.
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        # Typing-only signature for the current ``prompt`` form.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop first if it is not running and
        ``start_engine_loop`` is enabled. ``inputs`` is a deprecated alias
        for ``prompt``. ``arrival_time`` defaults to the current wall-clock
        time.

        Raises:
            AsyncEngineDeadError: if the loop is not running and may not be
                started automatically, or if it has already errored.
            ValueError: if a non-zero ``priority`` is given while the
                scheduler policy is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            # Only substitute "now" when no time was given; ``or`` would also
            # discard a legitimate arrival time of 0.0.
            arrival_time=(time.time()
                          if arrival_time is None else arrival_time),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: propagated from :meth:`add_request` when
                the background loop is not running and cannot be started.
            ValueError: propagated from :meth:`add_request` for a non-zero
                priority without priority scheduling.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request in the server.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling (e.g. embedding) model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: propagated from :meth:`add_request` when
                the background loop is not running and cannot be started.
            ValueError: propagated from :meth:`add_request` for a non-zero
                priority without priority scheduling.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request in the server.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, PoolingRequestOutput)

    async def abort(self, request_id: str) -> None:
        """Abort a request.
1148

Antoni Baum's avatar
Antoni Baum committed
1149
1150
        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.
1151

Antoni Baum's avatar
Antoni Baum committed
1152
1153
1154
        Args:
            request_id: The unique id of the request.
        """
1155
1156
1157
1158
1159
1160
1161
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

Antoni Baum's avatar
Antoni Baum committed
1162
        return self._abort(request_id)
1163

Antoni Baum's avatar
Antoni Baum committed
1164
    def _abort(self, request_id: str) -> None:
1165
1166
1167
1168
1169
1170
1171
1172
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
1173
        self._request_tracker.abort_request(request_id,
1174
                                            exception=asyncio.CancelledError,
1175
                                            verbose=self.log_requests)
1176

1177
1178
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted for interface
        compatibility but are not used here.
        """
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises ``AsyncEngineDeadError`` when the engine is stopped
        (:attr:`is_stopped`). Otherwise it runs the wrapped engine's
        ``check_health_async``, whose exceptions propagate. The elapsed time
        is logged at debug level.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
1214
        return self.engine.is_tracing_enabled()
1215
1216

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine under ``logger_name``."""
        # The parameter name shadows the module-level ``logger`` within this
        # method only; it is kept for keyword-argument compatibility.
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
1220
        self.engine.remove_logger(logger_name=logger_name)
1221
1222

    async def start_profile(self) -> None:
        """Start profiling on the model executor.

        An executor whose exact type is ``GPUExecutorAsync`` is called
        directly. Any other executor, including subclasses of
        ``GPUExecutorAsync``, runs ``start_profile`` on its workers via
        ``_run_workers``.
        """
        # using an exact type check instead of isinstance to avoid
        # capturing inherited classes
        if type(self.engine.model_executor) is GPUExecutorAsync:
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")

    async def stop_profile(self) -> None:
        """Stop profiling on the model executor.

        Mirrors :meth:`start_profile`. An executor whose exact type is
        ``GPUExecutorAsync`` is called directly. Any other executor runs
        ``stop_profile`` on its workers via ``_run_workers``.
        """
        # using an exact type check instead of isinstance to avoid
        # capturing inherited classes
        if type(self.engine.model_executor) is GPUExecutorAsync:
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")