# tokenizer_manager.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()
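        # Replaced below with a real image processor when the model is multimodal.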

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # Others
        self.gracefully_exit = False

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            async for response in self._wait_one_response(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_text = obj.text
        if obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
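            # Multimodal preprocessing may rewrite input_ids (e.g., to insert image
            # placeholder tokens); if so, the expanded ids are used from here on.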
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num

        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                obj.lora_path,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[obj.rid] = state

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            if isinstance(obj, GenerateReqInput):
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput,))
                out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    # Log requests
                    logger.info(f"in={obj}, out={out}")
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
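            # A prefill-only copy of each request (max_new_tokens=0, no streaming) is
            # sent first, so the shared prompt is computed once and stays cached before
            # the parallel samples below are dispatched.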
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self.send_to_scheduler.send_pyobj(tokenized_obj)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
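            # Outputs from all sub-requests are interleaved as they become ready;
            # "index" lets the caller map each chunk back to its originating request.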
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()

        self.send_to_scheduler.send_pyobj(req)
        self.mem_pool_size = asyncio.Future()
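        # Resolved in handle_loop() when GetMemPoolSizeReqOutput arrives
        # (one message per data-parallel worker when dp_size > 1).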

        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
        if self.server_args.dp_size == 1:
            res = await self.mem_pool_size
            return res.size
        else:  # self.server_args.dp_size > 1
            self.mem_pool_size_tmp = []
            res = await self.mem_pool_size
            ret = [r.size for r in res]
            return ret

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():

            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                for i in range(3):
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    # FIXME: We add some sleep here to avoid some race conditions.
                    # We can use a read-write lock as a better fix.
                    await asyncio.sleep(0.01)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()

                if self.server_args.dp_size == 1:
                    result = await self.model_update_result
                    if result.success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    return result.success, result.message
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp = []
                    result = await self.model_update_result

                    all_success = all([r.success for r in result])
                    if all_success is True:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    all_message = [r.message for r in result]
                    all_message = " | ".join(all_message)
                    return all_success, all_message

        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_child_process(include_self=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # set the future if all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                if self.server_args.dp_size == 1:
                    self.mem_pool_size.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.mem_pool_size_tmp.append(recv_obj)
                    # set the future if all results are received
                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
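        # Each entry becomes a (logprob, token_id, token_text or None) triple.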
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True