"src/diffusers/pipelines/flux/pipeline_flux_inpaint.py" did not exist on "2ee3215949d8f2d3141c2340d8e4d24ec94b2384"
scheduler.py 6.99 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional, Type, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter


class Scheduler:
    """Scheduler for handling requests to inference engine

    This class is responsible for handling all incoming requests.

    Args:
        max_batch_size (int): The max batch size that we can pass to the
            inference engine at a time.
    """

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
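        # Bookkeeping: `requests` records every request ever added, `streams`
        # holds the async token streams of streaming requests, and the three
        # pools partition requests by lifecycle stage (active/waiting/completed).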
        self.requests: Dict[str, InferenceRequest] = OrderedDict()
        self.streams: Dict[str, AsyncStream] = OrderedDict()
        self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.request_counter = Counter()

    def get_new_request_id(self) -> str:
        """Gets a new request id"""
        request_id = str(next(self.request_counter))
        return request_id

    def add_request(
        self,
        prompt: Optional[str] = None,
        prompt_tokens: Optional[torch.Tensor] = None,
        encoder_prompt: Optional[str] = None,
        inference_parameters: Optional[SamplingParams] = None,
        arrival_time: Optional[float] = None,
        streaming: bool = False,
        inference_request: Optional[InferenceRequest] = None,
    ) -> str:
        """Add an incoming request

        This method will add the request to either the active pool or the waiting pool
        depending on the batch size.

        Args:
            prompt (str): Input prompt string
            prompt_tokens (torch.Tensor): The tokenized input prompt
            encoder_prompt (str): Encoder input string
            inference_parameters (SamplingParams): The inference parameters
            arrival_time (float, optional): The incoming request time. Defaults to None.
            streaming (bool, optional): Whether to asynchronously stream tokens for this request.
            inference_request (InferenceRequest, optional): A fully constructed request.
                Defaults to None.

        Returns:
            The request_id for the new request.
        """
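        # Admit the request straight into the active pool if there is room;
        # otherwise park it in the FIFO waiting queue until a slot frees up.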
        status = (
            Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            if len(self.active_request_pool) < self.max_batch_size
            else Status.WAITING_IN_QUEUE
        )

        if inference_request is None:
            assert prompt is not None
            assert prompt_tokens is not None

            request_id = self.get_new_request_id()

            if arrival_time is None:
                arrival_time = time.time()

            inference_request = InferenceRequest(
                request_id=request_id,
                prompt=prompt,
                inference_parameters=inference_parameters,
                arrival_time=arrival_time,
                prompt_tokens=prompt_tokens,
                status=status,
                encoder_prompt=encoder_prompt,
            )
        else:
            request_id = inference_request.request_id
            inference_request.status = status
            if inference_request.arrival_time is None:
                inference_request.arrival_time = time.time()

        self.requests[request_id] = inference_request

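        # For streaming requests, attach an AsyncStream that consumers can
        # iterate over; its cancellation callback aborts this request.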
        if streaming:
            abort_request = functools.partial(self.abort_request, request_id=request_id)
            self.streams[request_id] = AsyncStream(request_id, abort_request)

        if status == Status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
            self.active_request_pool[request_id] = inference_request
        else:
            self.waiting_request_pool[request_id] = inference_request

        return request_id

    def have_requests_pending(self) -> bool:
        """Method to check if there are requests pending

        This method returns False only when both the active and waiting request pools are empty.
        """
        num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
        return num_requests_pending > 0

    def add_earliest_waiting_request_to_active_pool(self):
        """Utility to add the waiting request to active pool

        This method will add the earliest request (FIFO) that is in the waiting request
        pool to the active request pool.
        """
        assert (
            len(self.active_request_pool) < self.max_batch_size
        ), "Active request pool is already full. Can't add any more requests"
        if len(self.waiting_request_pool) > 0:
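            # popitem(last=False) pops the oldest entry, preserving FIFO order.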
            (earliest_waiting_request_request_id, earliest_waiting_request) = (
                self.waiting_request_pool.popitem(last=False)
            )
            earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            self.active_request_pool[earliest_waiting_request_request_id] = earliest_waiting_request

    def update_requests_pools(
        self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None
    ):
        """Update request pool status

        This method fills up the active request pool from the waiting request pool
        if the active pool has fewer than max_batch_size elements.
        If provided with a result dict, it will move the completed requests into the
        completed request pool and promote waiting requests into the active pool.

        Args:
            result_dict (typing.OrderedDict[str, InferenceRequest], optional): The result returned
                by the engine. A dictionary with keys as the request ids, and values as the
                requests. Defaults to None
        """
        if result_dict is not None:
            for result_request_id in list(result_dict.keys()):
                active_request = self.active_request_pool[result_request_id]

                # If a request has completed, move it into the completed request pool.
                if active_request.status == Status.COMPLETED:
                    completed_request = self.active_request_pool.pop(result_request_id)
                    self.completed_request_pool[result_request_id] = completed_request

        # If the active request pool is not full, add waiting requests in FIFO order
        while (
            len(self.active_request_pool) < self.max_batch_size
            and len(self.waiting_request_pool) > 0
        ):
            self.add_earliest_waiting_request_to_active_pool()

    def abort_request(
        self,
        request_id: str,
        *,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None
    ):
        """Cancels the given request"""
        stream = self.streams.get(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)
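

# A minimal usage sketch (illustrative, not part of the original module). It
# assumes megatron.core is importable and that SamplingParams() accepts its
# defaults; the token tensor below is a placeholder, not real tokenizer output.
if __name__ == "__main__":
    scheduler = Scheduler(max_batch_size=2)

    # With an empty active pool the request is admitted immediately.
    request_id = scheduler.add_request(
        prompt="Hello world",
        prompt_tokens=torch.tensor([101, 102, 103]),
        inference_parameters=SamplingParams(),
    )
    assert scheduler.have_requests_pending()

    # The inference engine would generate tokens for the active pool and mark
    # finished requests COMPLETED; simulate that here and hand the results back.
    scheduler.requests[request_id].status = Status.COMPLETED
    scheduler.update_requests_pools(result_dict=scheduler.active_request_pool)

    assert request_id in scheduler.completed_request_pool
    assert not scheduler.have_requests_pending()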