async_llm_engine.py 50.4 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
10
from typing_extensions import deprecated

11
import vllm.envs as envs
12
13
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
14
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.engine.arg_utils import AsyncEngineArgs
16
from vllm.engine.async_timeout import asyncio_timeout
17
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
18
from vllm.engine.metrics_types import StatLoggerBase
19
from vllm.engine.protocol import EngineClient
20
from vllm.executor.executor_base import ExecutorAsyncBase
21
from vllm.executor.gpu_executor import GPUExecutorAsync
22
from vllm.executor.ray_utils import initialize_ray_cluster
23
from vllm.inputs import PromptType
24
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
29
from vllm.model_executor.layers.sampler import SamplerOutput
30
from vllm.outputs import PoolingRequestOutput, RequestOutput
31
from vllm.pooling_params import PoolingParams
32
from vllm.prompt_adapter.request import PromptAdapterRequest
33
from vllm.sampling_params import SamplingParams
34
from vllm.sequence import ExecuteModelRequest
35
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
36
from vllm.usage.usage_lib import UsageContext
37
from vllm.utils import deprecate_kwargs, weak_bind
38
39

logger = init_logger(__name__)
40
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
41

Antoni Baum's avatar
Antoni Baum committed
42

43
44
45
46
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background task has stopped unexpectedly.

    A ``RuntimeError`` subclass, so generic ``RuntimeError`` handlers still
    catch it.
    """


47
48
49
50
51
52
53
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Inspect the outcome of the finished `engine.run_engine_loop()` task.

    This function is only intended for that task, which runs a `while True`
    loop and therefore can only stop by being cancelled or by raising.

    * Cancellation is treated as a graceful shutdown and only logged.
    * Any other outcome (an exception, or an unexpected normal return) is
      logged, passed to ``error_callback``, and re-raised as
      :class:`AsyncEngineDeadError` chained to the original error.

    Args:
        task: The completed engine-loop task.
        error_callback: Called with the failure before the re-raise.

    Raises:
        AsyncEngineDeadError: whenever the task did not end by cancellation.
    """
    try:
        result = task.result()
        # A normal return is itself a failure; raising here routes it through
        # the generic failure handling below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {result}")
    except asyncio.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from err
73
74


75
76
77
# Sentinel queued by AsyncStream.finish() when there is no exception to
# deliver. Being an Exception instance it passes AsyncStream._is_raisable();
# AsyncStream.generator() then recognises it and ends iteration normally
# instead of raising it.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
78
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with ``request_id`` if the consumer closes the generator
        # before the stream is exhausted.
        self._cancel = cancel
        # Holds outputs, then exactly one terminal item: either a raisable
        # exception (instance or class) or the STOP_ITERATION sentinel.
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue an item for the consumer.

        Items put after :meth:`finish` has been called are silently dropped.
        """
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished. Only the first call has an effect.

        Args:
            exception: If an exception instance or class, the generator
                raises it once earlier items are consumed. Otherwise the
                generator ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            The exception passed to :meth:`finish`, if any.
            asyncio.CancelledError: if the generator is closed early; the
                request is cancelled via the ``cancel`` callback first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity, not equality: an exception type overriding
                    # __eq__ must never be mistaken for the sentinel.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer closed the generator before exhaustion.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        """True if ``value`` is an exception instance or exception class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
127

128
129
130
131
132
class RequestTracker:
    """Synchronous bookkeeping for requests handled by the engine loop.

    New requests and aborts are buffered in asyncio queues and handed to the
    engine in batches by :meth:`get_new_and_aborted_requests`. Each accepted
    request gets an :class:`AsyncStream` that receives its outputs.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids whose abort has not yet been collected for the engine.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet collected.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); waited on by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request (or every tracked request when ``request_id``
        is None), finishing its stream with ``exc``."""
        # Snapshot ids first: abort_request pops from _request_streams, so
        # the dict cannot be iterated directly while aborting.
        if request_id is None:
            targets = tuple(self._request_streams)
        else:
            targets = (request_id, )
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route one engine output to its request's stream.

        A finished output also finishes the stream and stops tracking it.
        """
        request_id = request_output.request_id
        finished = request_output.finished
        streams = self._request_streams

        stream = (streams.pop(request_id, None)
                  if finished else streams.get(request_id))

        # The stream may already be gone if the request was aborted while
        # this output was being generated.
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Abort ``request_id``, finishing its stream with ``exception``."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the engine's next background loop iteration.

        Raises:
            KeyError: if ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # Closing the stream's generator early aborts this request.
        stream = AsyncStream(request_id,
                             partial(self.abort_request, verbose=verbose))
        engine_kwargs = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, engine_kwargs))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during the next background loop iteration.

        The id is always queued for the engine; the stream, if tracked, is
        finished immediately (with ``exception`` when one is given).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues for the engine.

        Returns:
            (kwargs dicts of newly accepted requests, ids to abort). A request
            aborted before it was ever collected is dropped from both, and its
            stream is finished with ``asyncio.CancelledError``.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            finished_requests.add(self._aborted_requests.get_nowait())

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in finished_requests:
                self._request_streams[rid] = stream
                new_requests.append(new_request)
                continue
            # The request has already been aborted.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(rid)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until a new request is queued (returns at once if one is
        already pending), then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any queued request has not been collected yet."""
        return not self._new_requests.empty()
260

Antoni Baum's avatar
Antoni Baum committed
261
262
263
264

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the scheduler, scheduler context
                and cached scheduler outputs this step operates on.

        Returns:
            ``ctx.request_outputs`` of this virtual engine's scheduler
            context. It is cleared at the start of every call.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop (delegates to the model
        executor)."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base tokenizer
        when it is None) from the engine's tokenizer group."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload
    @deprecated("'inputs' will be renamed to 'prompt")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: if ``lora_request`` is given while LoRA is disabled,
                or a non-zero ``priority`` is given while the scheduler
                policy is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        # Resolved at most once and reused by the guided-decoding path below,
        # instead of awaiting get_tokenizer_async a second time.
        tokenizer: Optional[AnyTokenizer] = None
        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            if tokenizer is None:
                tokenizer = await self.get_tokenizer_async(lora_request)
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=tokenizer,
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is configured)
        and the model executor's; both checks are synchronous calls."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
527

528

529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Turn ``sampling_params.guided_decoding`` into a logits processor.

    If ``sampling_params.guided_decoding`` is None, the params are returned
    untouched. Otherwise:

    1. the guided-decoding backend falls back to ``default_guided_backend``
       when not set,
    2. a logits processor is built asynchronously for it and, if one is
       returned, appended to ``sampling_params.logits_processors`` (the list
       is created when it is None),
    3. ``sampling_params.guided_decoding`` is cleared.

    Only the ``guided_decoding`` field is consumed; other fields are left
    as-is. ``sampling_params`` is modified in place and returned.

    Args:
        sampling_params: Params to update in place.
        tokenizer: Tokenizer handed to the processor builder.
        default_guided_backend: Backend used when the params name none.

    Returns:
        The same ``sampling_params`` object.
    """
    if (guided_decoding := sampling_params.guided_decoding) is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


559
class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Positional arguments forwarded to ``_engine_class``.
        **kwargs: Keyword arguments forwarded to ``_engine_class``.
    """

    # Concrete engine type instantiated by __init__; overridable by
    # subclasses.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the wrapped engine and initialise background-loop state.

        Args:
            log_requests: Whether request lifecycle events are logged.
            start_engine_loop: Whether ``add_request`` may start the
                background loop on demand.
            *args: Forwarded to ``self._engine_class``.
            **kwargs: Forwarded to ``self._engine_class``.
        """
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # weak_bind presumably binds without keeping `self` alive, so the
            # engine's callback does not pin this wrapper — confirm.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop)
        self._request_tracker: RequestTracker

    def __del__(self):
        """Wake the background loop so it can observe this engine is gone.

        ``run_engine_loop`` drops its strong reference while it waits on the
        request tracker, then exits once the weakref is dead. Setting the
        tracker's event un-blocks that wait.
        """
        # The tracker lives in ``_request_tracker`` and only exists once
        # start_background_loop() has run, hence getattr with a default.
        # (Looking up "request_tracker" never matched, so the wake-up was
        # silently skipped.)
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

    @classmethod
    def _get_executor_cls(
            cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        Selection order:

        1. If ``parallel_config.distributed_executor_backend`` is a class,
           it is used directly after checking it subclasses
           :class:`ExecutorAsyncBase`.
        2. Otherwise, by ``device_config.device_type`` (neuron, tpu, cpu,
           hpu, openvino, xpu), honouring the backend name where the device
           supports several.
        3. Otherwise (GPU): ``"ray"`` -> Ray, ``"mp"`` -> multiprocessing,
           anything else -> single-process :class:`GPUExecutorAsync`.

        Raises:
            TypeError: The backend class does not derive from
                ``ExecutorAsyncBase``.
            RuntimeError: Unsupported backend name on an XPU device.
            AssertionError: A backend other than ``"ray"``/``None`` on TPU,
                or any distributed backend on OpenVINO.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "hpu":
            if distributed_executor_backend == "ray":
                # NOTE(review): from_engine_args also calls
                # initialize_ray_cluster when executor_class.uses_ray;
                # presumably idempotent — confirm.
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
                executor_class = RayHPUExecutorAsync
            else:
                from vllm.executor.hpu_executor import HPUExecutorAsync
                executor_class = HPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            # GPUExecutorAsync is already imported at module level.
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[VllmConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Args:
            engine_args: Parsed engine arguments.
            engine_config: Pre-built config; built from ``engine_args`` via
                ``create_engine_config(usage_context)`` when ``None``.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Usage context for config creation and the engine.
            stat_loggers: Optional named stat loggers for the engine.

        Returns:
            A new engine of type ``cls``.
        """
        # Create the engine configs.
        if engine_config is None:
            engine_config = engine_args.create_engine_config(usage_context)

        executor_class = cls._get_executor_cls(engine_config)

        if executor_class.uses_ray:
            initialize_ray_cluster(engine_config.parallel_config)

        # Create the async LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its started background task is done.

        A loop that was never started is not considered stopped.
        """
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once ``set_errored`` has recorded an exception."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """A fresh ``AsyncEngineDeadError`` describing a stopped loop.

        The exception is returned, not raised; callers ``raise`` it.
        """
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is unusable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and push ``exc`` to all tracked requests.

        Installed (via ``_log_task_completion``) as the background task's
        done-callback error handler in ``start_background_loop``.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base tokenizer).

        Delegates to the wrapped engine's ``get_tokenizer_async``.
        """
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh ``RequestTracker`` and schedules ``run_engine_loop``
        as a task on the current event loop.

        Raises:
            AsyncEngineDeadError: The engine has already errored (chained
                from the recorded exception).
            RuntimeError: The loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop only holds a weakref to `self` so it doesn't keep the
        # engine alive.
        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Callers awaiting background_loop get a shielded view, so cancelling
        # their await does not cancel the underlying task.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).

        Cancels the unshielded task (if any) and clears both task handles,
        so ``is_running`` becomes False afterwards.
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Drains new/aborted requests from the tracker into the wrapped
        engine, runs one ``step_async`` for ``virtual_engine`` and routes
        the outputs to their streams.

        Args:
            virtual_engine: Index of the pipeline-parallel virtual engine.

        Returns:
            True if any returned request output is not yet finished.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                # Fail only this request's stream; keep serving others.
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Forward each output to its request stream.

        Every output is handed to the tracker (no short-circuit, even after
        an unfinished one is seen).

        Returns:
            True if every output is finished (vacuously True when empty).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the wrapped engine."""
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        Runs one ``engine_step`` task per pipeline-parallel virtual engine.
        When none has work, it parks on the request tracker until new
        requests arrive. It returns when the engine has been collected, and
        re-raises ``asyncio.TimeoutError`` (after ``set_errored``) if an
        iteration exceeds ``ENGINE_ITERATION_TIMEOUT_S``.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    # Task position in the list is its virtual-engine index.
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        # Typing-only overload for the deprecated keyword-only `inputs`.
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        # Typing-only overload for the current `prompt`-based signature.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop first if it is not running and
        ``start_engine_loop`` is enabled.

        Args:
            request_id: Unique id of the request.
            prompt: The prompt (``inputs`` is a deprecated alias).
            params: Sampling or pooling parameters.
            arrival_time: Arrival timestamp; ``time.time()`` if ``None``.
            lora_request: Optional LoRA request.
            trace_headers: Optional OpenTelemetry trace headers.
            prompt_adapter_request: Optional prompt-adapter request.
            priority: Request priority; must be 0 unless the scheduler
                policy is ``"priority"``.

        Returns:
            An async generator over the request's outputs.

        Raises:
            AsyncEngineDeadError: The loop is not running and may not be
                started automatically.
            ValueError: Non-zero priority without priority scheduling.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Use an explicit None check so an explicit 0.0 timestamp is kept.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, PoolingRequestOutput)

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: The background loop is not running.
        """
        if not self.is_running:
            raise self.dead_error

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        The tracker is told to end the request's stream with
        ``asyncio.CancelledError``.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted but not
        forwarded; the engine's ``do_log_stats`` is called with no
        arguments. (Presumably they exist to match the ``EngineClient``
        interface — confirm.)
        """
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: The background loop is stopped.
            Exception: Whatever the wrapped engine's ``check_health_async``
                raises.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a named stat logger on the wrapped engine."""
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the named stat logger from the wrapped engine."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Start profiling on the executor (directly or on every worker)."""
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
            self.engine.model_executor.start_profile()
        else:
            self.engine.model_executor._run_workers("start_profile")

    async def stop_profile(self) -> None:
        """Stop profiling on the executor (directly or on every worker)."""
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(self.engine.model_executor) == GPUExecutorAsync:  # noqa: E721
            self.engine.model_executor.stop_profile()
        else:
            self.engine.model_executor._run_workers("stop_profile")