async_llm_engine.py 50.3 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.engine.protocol import EngineClient
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.gpu_executor import GPUExecutorAsync
20
from vllm.executor.ray_utils import initialize_ray_cluster
21
from vllm.inputs import PromptType
22
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.logger import init_logger
24
from vllm.lora.request import LoRARequest
25
26
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
27
from vllm.model_executor.layers.sampler import SamplerOutput
28
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
29
from vllm.pooling_params import PoolingParams
30
from vllm.prompt_adapter.request import PromptAdapterRequest
31
from vllm.sampling_params import SamplingParams
32
from vllm.sequence import ExecuteModelRequest
33
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
34
from vllm.usage.usage_lib import UsageContext
35
from vllm.utils import deprecate_kwargs, weak_bind
36
37

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)
38
# Engine iteration timeout taken from vLLM's environment settings (`envs`).
# NOTE(review): the `_S` suffix suggests seconds, and it is presumably applied
# per engine-loop iteration outside this view — confirm at the use site.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
39

Antoni Baum's avatar
Antoni Baum committed
40

41
42
43
44
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background task has stopped unexpectedly."""


45
46
47
48
49
50
51
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Handle completion of the `engine.run_engine_loop()` task.

    Only intended for that task: it runs a `while True` loop, so the only
    expected ways for it to end are cancellation or an exception. A normal
    return is itself treated as a failure.

    Args:
        task: The background task whose outcome is inspected.
        error_callback: Called with the exception that ended the task, just
            before `AsyncEngineDeadError` is raised.

    Raises:
        AsyncEngineDeadError: If the task raised or returned normally.
    """
    try:
        unexpected_result = task.result()
        # Reaching this point means the loop returned, which is a bug; raise
        # so the generic handler below reports it like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {unexpected_result}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is taken to mean a graceful shutdown, which should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
71
72


73
74
75
# Sentinel enqueued by `AsyncStream.finish()` when it is given no raisable
# exception; `AsyncStream.generator()` ends iteration cleanly when it dequeues
# this object. Consumers compare against this single module-level instance.
STOP_ITERATION: Exception = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
76
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an unbounded ``asyncio.Queue``. The stream ends
    when ``finish()`` enqueues its terminal item: either an exception (which
    the consumer re-raises) or the ``STOP_ITERATION`` sentinel (clean end).
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """
        Args:
            request_id: Identifier of the request this stream belongs to.
            cancel: Called with ``request_id`` if the consumer closes the
                generator before the stream finishes.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an item; silently ignored once the stream is finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and enqueue its terminal item.

        If ``exception`` is an exception instance or class it is enqueued so
        the consumer re-raises it; otherwise ``STOP_ITERATION`` is enqueued
        to end iteration cleanly. Calls after the first are no-ops.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether ``finish()`` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued items until the stream's terminal item is reached.

        Raises the terminal exception if one was enqueued. If the generator
        is closed early (``GeneratorExit``), ``cancel(request_id)`` is called
        and ``asyncio.CancelledError`` is raised instead.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: STOP_ITERATION is a unique sentinel
                    # object, and ``==`` could be overridden by an exception
                    # type that defines its own ``__eq__``.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True for an exception instance or an exception class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
125

126
127
128
129
130
class RequestTracker:
    """Tracks per-request output streams for the async engine.

    New and aborted requests are queued here and handed over in batches by
    ``get_new_and_aborted_requests()``; engine outputs and errors are routed
    back to the matching ``AsyncStream``. Every method except
    ``wait_for_new_requests()`` is a plain synchronous call.
    """

    def __init__(self) -> None:
        # request_id -> stream, registered by get_new_and_aborted_requests().
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids aborted since the last get_new_and_aborted_requests() call.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs awaiting hand-over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request() whenever something is queued.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request's stream with ``exc``, or every tracked stream
        when ``request_id`` is None."""
        # Snapshot the ids up front: abort_request() pops from
        # _request_streams, so the dict cannot be iterated directly.
        targets = ([request_id] if request_id is not None else
                   list(self._request_streams))
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to its stream, closing it when finished."""
        rid = request_output.request_id
        is_done = request_output.finished

        # Finished requests stop being tracked; unfinished ones stay.
        streams = self._request_streams
        stream = streams.pop(rid, None) if is_done else streams.get(rid)

        # The stream may be gone already if the request was aborted while
        # this output was being generated.
        if stream is not None:
            stream.put(request_output)
            if is_done:
                stream.finish()

        if verbose and is_done:
            logger.info("Finished request %s.", rid)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Fail a single request's stream with an engine-side exception."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background-loop iteration and return
        its output stream.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id,
                             partial(self.abort_request, verbose=verbose))
        payload = dict(request_id=request_id, **engine_add_request_kwargs)
        self._new_requests.put_nowait((stream, payload))
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)
        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Schedule ``request_id`` for abort and close its stream (if any)
        with ``exception``."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues, returning the new request kwargs to hand to the
        engine and the set of request ids to abort there."""
        aborted: Set[str] = set()
        while not self._aborted_requests.empty():
            aborted.add(self._aborted_requests.get_nowait())

        accepted: List[Dict] = []
        while not self._new_requests.empty():
            stream, payload = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in aborted:
                self._request_streams[rid] = stream
                accepted.append(payload)
                continue
            # Aborted before it was ever handed over: cancel the stream and
            # drop the id from the abort set, since it is not passed on.
            stream.finish(asyncio.CancelledError)
            aborted.discard(rid)

        return accepted, aborted

    async def wait_for_new_requests(self):
        """Wait until a new request is queued, then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is queued but not yet handed over."""
        return not self._new_requests.empty()
258

Antoni Baum's avatar
Antoni Baum committed
259
260
261
262

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods.

    Adds coroutine counterparts for stepping the engine, admitting requests,
    fetching tokenizers and running health checks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting this step's scheduler, scheduler
                context and cached scheduler outputs.

        Returns:
            The scheduler context's ``request_outputs`` after this step.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` from the tokenizer
        group."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Validates the request, preprocesses the prompt, builds guided-decoding
        logits processors asynchronously when requested, and passes the
        result to ``_add_processed_request``.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is disabled,
                or a non-zero ``priority`` is given while the scheduler policy
                is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        # Fetched at most once per request and shared between prompt
        # validation and guided decoding below.
        tokenizer: Optional[AnyTokenizer] = None
        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            if tokenizer is None:
                tokenizer = await self.get_tokenizer_async(lora_request)
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=tokenizer,
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer (if any) and model-executor health checks.

        Both checks are invoked synchronously.
        """
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
521

522

523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Turn ``sampling_params.guided_decoding`` into a logits processor.

    If ``guided_decoding`` is None, ``sampling_params`` is returned as-is.
    Otherwise:

    1. An unset backend falls back to ``default_guided_backend``.
    2. A processor is built with ``get_guided_decoding_logits_processor``.
    3. If a processor is returned, it is appended to
       ``sampling_params.logits_processors``, creating the list if needed.
    4. ``guided_decoding`` is cleared.

    ``sampling_params`` is modified in place and also returned.
    """
    guided_decoding = sampling_params.guided_decoding
    if guided_decoding is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # The guided-decoding spec has been consumed into the processor above.
    sampling_params.guided_decoding = None
    return sampling_params


553
class AsyncLLMEngine(EngineClient):
554
    """An asynchronous wrapper for :class:`LLMEngine`.
555

556
557
558
559
560
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
561
562

    Args:
563
        log_requests: Whether to log the requests.
564
565
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
566
567
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
568
    """
569

Antoni Baum's avatar
Antoni Baum committed
570
571
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

572
573
574
    def __init__(self,
                 *args,
                 log_requests: bool = True,
575
                 start_engine_loop: bool = True,
576
                 **kwargs) -> None:
577
        self.log_requests = log_requests
578
        self.engine = self._engine_class(*args, **kwargs)
Antoni Baum's avatar
Antoni Baum committed
579

580
581
582
        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
583
584
585
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

586
587
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
588
                weak_bind(self.process_request_outputs)
589

590
        self.background_loop: Optional[asyncio.Future] = None
591
592
593
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
594
        self._background_loop_unshielded: Optional[asyncio.Task] = None
595
        self.start_engine_loop = start_engine_loop
596
        self._errored_with: Optional[BaseException] = None
Antoni Baum's avatar
Antoni Baum committed
597

598
599
600
        # Lazy initialized fields
        self._request_tracker: RequestTracker

601
602
603
604
605
    def __del__(self):
        if rt := getattr(self, "request_tracker", None):
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

606
    @classmethod
607
    def _get_executor_cls(
608
            cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
609
610
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
611
612
613
614
615
616
617
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
618
619
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
620
        elif engine_config.device_config.device_type == "tpu":
621
622
623
624
625
626
627
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
628
629
630
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
631
632
633
634
635
636
637
638
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
                executor_class = RayHPUExecutorAsync
            else:
                from vllm.executor.hpu_executor import HPUExecutorAsync
                executor_class = HPUExecutorAsync
639
640
641
642
643
644
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
645
646
647
648
649
650
651
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
652
653
654
655
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
656
657
658
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
659
        elif distributed_executor_backend == "ray":
660
661
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
662
663
664
665
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
666
667
668
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
669
670
671
672
673
674
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
675
        engine_config: Optional[VllmConfig] = None,
676
677
678
679
680
681
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
682
683
        if engine_config is None:
            engine_config = engine_args.create_engine_config()
684
685
686

        executor_class = cls._get_executor_cls(engine_config)

687
688
689
        if executor_class.uses_ray:
            initialize_ray_cluster(engine_config.parallel_config)

690
        # Create the async LLM engine.
yhu422's avatar
yhu422 committed
691
        engine = cls(
692
            vllm_config=engine_config,
693
            executor_class=executor_class,
yhu422's avatar
yhu422 committed
694
695
696
697
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
698
            stat_loggers=stat_loggers,
yhu422's avatar
yhu422 committed
699
        )
700
701
        return engine

702
703
    @property
    def is_running(self) -> bool:
704
        return (self.background_loop is not None
705
                and self._background_loop_unshielded is not None
706
707
708
709
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
710
711
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
712
713
714
715
716
717
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

718
    @property
719
720
721
722
723
724
    def dead_error(self) -> BaseException:
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
725

726
727
728
729
730
731
    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
732

733
734
735
    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.engine.input_preprocessor

736
737
738
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
739
    ) -> AnyTokenizer:
740
        return await self.engine.get_tokenizer_async(lora_request)
741

742
    def start_background_loop(self) -> None:
Antoni Baum's avatar
Antoni Baum committed
743
        """Start the background loop."""
744
745
746
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
747
        if self.is_running:
Antoni Baum's avatar
Antoni Baum committed
748
            raise RuntimeError("Background loop is already running.")
749
750
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()
751
752

        self._background_loop_unshielded = asyncio.get_event_loop(
753
        ).create_task(self.run_engine_loop(weakref.ref(self)))
754
        self._background_loop_unshielded.add_done_callback(
755
            partial(_log_task_completion, error_callback=self._error_callback))
756
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
Antoni Baum's avatar
Antoni Baum committed
757

758
759
760
761
762
763
764
765
766
767
768
769
770
771
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

772
    async def engine_step(self, virtual_engine: int) -> bool:
773
774
775
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""
776

777
778
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())
779
780
781

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
782
            try:
783
                await self.engine.add_request_async(**new_request)
784
785
786
787
788
789
790
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )
791

792
793
        if aborted_requests:
            await self._engine_abort(aborted_requests)
794

795
        request_outputs = await self.engine.step_async(virtual_engine)
796

Antoni Baum's avatar
Antoni Baum committed
797
        # Put the outputs into the corresponding streams.
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        # Put the outputs into the corresponding streams.
        all_finished = True
813
        for request_output in request_outputs:
814
            self._request_tracker.process_request_output(
815
                request_output, verbose=self.log_requests)
816
            all_finished = all_finished and request_output.finished
Antoni Baum's avatar
Antoni Baum committed
817

818
        return all_finished
819

Antoni Baum's avatar
Antoni Baum committed
820
    async def _engine_abort(self, request_ids: Iterable[str]):
821
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
822

823
824
825
826
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected."""
827
        engine: Optional[AsyncLLMEngine] = engine_ref()
828
829
830
        if not engine:
            return

831
        pipeline_parallel_size = \
832
                engine.engine.parallel_config.pipeline_parallel_size
833
        has_requests_in_progress = [False] * pipeline_parallel_size
Antoni Baum's avatar
Antoni Baum committed
834
        while True:
835
            if not any(has_requests_in_progress):
836
                logger.debug("Waiting for new requests...")
837
838
839
840
841
842
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
843
844
845
846
847
848
849
850
851
852
853
854
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
855
                logger.debug("Got new requests!")
856
                requests_in_progress = [
857
                    asyncio.create_task(engine.engine_step(ve))
858
859
860
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size
861
862
863
864

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
865
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
866
867
868
869
870
871
872
873
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
874
                    has_unfinished_requests = (
875
876
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
877
                            virtual_engine))
878
879
880
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
881
                                engine.engine_step(virtual_engine)))
882
883
884
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
885
886
887
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
888
                engine.set_errored(exc)
889
                raise
Antoni Baum's avatar
Antoni Baum committed
890
891
            await asyncio.sleep(0)

892
893
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
894
895
    @overload  # DEPRECATED
    def add_request(
896
897
        self,
        request_id: str,
898
899
        *,
        inputs: PromptType,
900
        params: Union[SamplingParams, PoolingParams],
901
902
903
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
904
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
905
        priority: int = 0,
906
907
908
909
910
911
912
913
914
915
916
917
918
919
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
920
        priority: int = 0,
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
938
        priority: int = 0,
939
940
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
941
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
942
943
944
945
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

946
        if not self.is_running:
947
948
949
950
951
952
953
954
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
Antoni Baum's avatar
Antoni Baum committed
955

956
957
958
959
960
        if (priority != 0
                and not self.engine.scheduler_config.policy == "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

961
        stream = self._request_tracker.add_request(
962
            request_id,
963
            verbose=self.log_requests,
964
            prompt=prompt,
965
            params=params,
966
            arrival_time=arrival_time or time.time(),
967
            lora_request=lora_request,
968
            trace_headers=trace_headers,
969
970
971
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
Antoni Baum's avatar
Antoni Baum committed
972

973
        return stream.generator()
974

975
    async def generate(
976
        self,
977
        prompt: PromptType,
978
979
        sampling_params: SamplingParams,
        request_id: str,
980
        lora_request: Optional[LoRARequest] = None,
981
        trace_headers: Optional[Mapping[str, str]] = None,
982
983
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
984
    ) -> AsyncGenerator[RequestOutput, None]:
985
986
987
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
988
989
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.
990
991

        Args:
992
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
993
                for more details about the format of each input.
994
995
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
996
            lora_request: LoRA request to use for generation, if any.
997
            trace_headers: OpenTelemetry trace headers.
998
            prompt_adapter_request: Prompt Adapter request to use
999
                                            for generation, if any.
1000
1001
            priority: The priority of the request.
                Only applicable with priority scheduling.
1002
1003

        Yields:
1004
1005
            The output `RequestOutput` objects from the LLMEngine
            for the request.
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
1023
            >>> # note that engine_args here is AsyncEngineArgs instance
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
1050
        """
1051
        async for output in await self.add_request(
1052
                request_id,
1053
                prompt,
1054
                sampling_params,
1055
                lora_request=lora_request,
1056
                trace_headers=trace_headers,
1057
                prompt_adapter_request=prompt_adapter_request,
1058
                priority=priority,
1059
        ):
1060
            yield LLMEngine.validate_output(output, RequestOutput)
1061
1062
1063

    async def encode(
        self,
1064
        prompt: PromptType,
1065
1066
1067
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
1068
        trace_headers: Optional[Mapping[str, str]] = None,
1069
        priority: int = 0,
1070
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
1071
1072
1073
1074
1075
1076
1077
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
1078
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
1079
                for more details about the format of each input.
1080
1081
1082
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
1083
            trace_headers: OpenTelemetry trace headers.
1084
1085
            priority: The priority of the request.
                Only applicable with priority scheduling.
1086
1087

        Yields:
1088
            The output `EmbeddingRequestOutput` objects from the LLMEngine
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
1107
            >>> # note that engine_args here is AsyncEngineArgs instance
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
1133
        async for output in await self.add_request(
1134
                request_id,
1135
                prompt,
1136
                pooling_params,
1137
                lora_request=lora_request,
1138
                trace_headers=trace_headers,
1139
                priority=priority,
1140
        ):
1141
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
1142

Antoni Baum's avatar
Antoni Baum committed
1143
1144
    async def abort(self, request_id: str) -> None:
        """Abort a request.
1145

Antoni Baum's avatar
Antoni Baum committed
1146
1147
        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.
1148

Antoni Baum's avatar
Antoni Baum committed
1149
1150
1151
        Args:
            request_id: The unique id of the request.
        """
1152
1153
1154
1155
1156
1157
1158
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

Antoni Baum's avatar
Antoni Baum committed
1159
        return self._abort(request_id)
1160

Antoni Baum's avatar
Antoni Baum committed
1161
    def _abort(self, request_id: str) -> None:
1162
1163
1164
1165
1166
1167
1168
1169
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
1170
        self._request_tracker.abort_request(request_id,
1171
                                            exception=asyncio.CancelledError,
1172
                                            verbose=self.log_requests)
1173

1174
1175
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
1176
        return self.engine.get_model_config()
1177

1178
1179
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
1180
        return self.engine.get_parallel_config()
1181

1182
1183
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
1184
        return self.engine.get_decoding_config()
1185

1186
1187
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
1188
        return self.engine.get_scheduler_config()
1189
1190
1191

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
1192
        return self.engine.get_lora_config()
1193

1194
1195
1196
1197
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
1198
        self.engine.do_log_stats()
1199

1200
    async def check_health(self) -> None:
1201
1202
1203
1204
1205
1206
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

1207
        await self.engine.check_health_async()
1208
        logger.debug("Health check took %fs", time.perf_counter() - t)
1209
1210

    async def is_tracing_enabled(self) -> bool:
1211
        return self.engine.is_tracing_enabled()
1212
1213

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
1214
        self.engine.add_logger(logger_name=logger_name, logger=logger)
1215
1216

    def remove_logger(self, logger_name: str) -> None:
1217
        self.engine.remove_logger(logger_name=logger_name)
1218
1219

    async def start_profile(self) -> None:
1220
1221
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
1222
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
1223
1224
1225
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")
1226
1227

    async def stop_profile(self) -> None:
1228
1229
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
1230
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
1231
1232
1233
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")