"src/llamafactory/model/model_utils/packing.py" did not exist on "6cedac147c989817368a06ac32ea206ed05e7232"
interpreter.py 32.5 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
"""The interpreter that executes SGL programs"""

import asyncio
import contextvars
import copy
import multiprocessing
import queue
import threading
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional

import tqdm

from sglang.global_config import global_config
from sglang.lang.ir import (
    SglCommitLazy,
    SglConcateAndAppend,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglGen,
    SglImage,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
    SglVideo,
)
from sglang.utils import (
    encode_image_base64,
    encode_video_base64,
    get_exception_traceback,
)


def run_internal(state, program, func_args, func_kwargs, sync):
    """Run ``program.func`` against ``state`` and finalize its stream executor.

    The return value of the program function is stored in ``state.ret_value``.
    The executor's ``end()`` is always called, even if the program raises; the
    exception then propagates unchanged.

    Args:
        state: The ``ProgramState`` the program writes into.
        program: An object exposing ``func(state, *args, **kwargs)``.
        func_args: Positional arguments for ``program.func``.
        func_kwargs: Keyword arguments for ``program.func``.
        sync: If True, wait for the executor to process everything that was
            submitted before returning.
    """
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    finally:
        # Always signal end-of-program, even when the program function raised.
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())


def run_program(
    program,
    backend,
    func_args,
    func_kwargs,
    default_sampling_para,
    stream,
    sync=False,
    use_thread=True,
):
    """Execute one SGL program and return its ``ProgramState``.

    Args:
        program: The SGL program; its ``bind_arguments`` are merged into the
            keyword arguments, taking precedence over ``func_kwargs``.
        backend: The backend, or an object exposing it as ``.endpoint``.
        func_args: Positional arguments for the program function.
        func_kwargs: Keyword arguments for the program function. This mapping
            is not modified.
        default_sampling_para: Default sampling parameters for ``gen`` calls.
        stream: If True, run the program on a background thread and return the
            state immediately so the caller can consume output as it streams.
        sync: Forwarded to ``run_internal``. When True, wait for all submitted
            expressions to finish after the program function returns.
        use_thread: Whether the stream executor uses a worker thread.

    Returns:
        The ``ProgramState`` wrapping the program's stream executor.
    """
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint
    assert backend is not None, "Please specify a backend"
    # Build a new dict instead of calling func_kwargs.update(): callers such as
    # run_program_batch pass the user's own per-item dicts, which must not be
    # mutated as a side effect.
    func_kwargs = {**func_kwargs, **program.bind_arguments}
    stream_executor = StreamExecutor(
        backend,
        func_kwargs,
        default_sampling_para,
        chat_template=None,
        stream=stream,
        num_api_spec_tokens=program.num_api_spec_tokens,
        use_thread=use_thread,
    )
    state = ProgramState(stream_executor)

    if stream:
        # Run in the background so the caller can read the stream as it fills.
        t = threading.Thread(
            target=run_internal, args=(state, program, func_args, func_kwargs, sync)
        )
        t.start()
    else:
        run_internal(state, program, func_args, func_kwargs, sync)
    return state


def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
    generator_style=False,
):
    """Run ``program`` once for each keyword-argument dict in ``batch_arguments``.

    Args:
        program: The SGL program to run.
        backend: The backend, or an object exposing it as ``.endpoint``.
        batch_arguments: A sized, indexable sequence of keyword-argument dicts,
            one per run.
        default_sampling_para: Default sampling parameters for every run.
        num_threads: Number of worker threads, or ``"auto"``. The value is
            clamped to ``[1, len(batch_arguments)]``.
        progress_bar: Whether to show a tqdm progress bar.
        generator_style: If True, return a generator that yields states in
            input order instead of a list.

    Returns:
        A list of ``ProgramState`` objects in input order, or a generator of
        them when ``generator_style`` is True. An empty batch produces an empty
        result.
    """
    if hasattr(backend, "endpoint"):
        backend = backend.endpoint

    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
        cache_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = max(96, multiprocessing.cpu_count() * 16)
    # Never use more threads than runs, but keep at least one. Without the
    # lower bound an empty batch would reach ThreadPoolExecutor(0), which
    # raises ValueError.
    num_threads = max(1, min(num_threads, len(batch_arguments)))

    if generator_style:
        return _run_program_batch_generator(
            program,
            backend,
            batch_arguments,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def run_one(arguments):
        # Positional flags: stream=False, sync=True.
        return run_program(
            program,
            backend,
            (),
            arguments,
            default_sampling_para,
            False,
            True,
        )

    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        return [run_one(arguments) for arguments in iterator]

    pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None
    try:
        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                future = executor.submit(run_one, arguments)
                if pbar is not None:
                    future.add_done_callback(lambda _: pbar.update())
                futures.append(future)

            rets = [f.result() for f in futures]
        # This branch is only reached with at least two runs, so rets is non-empty.
        rets[-1].sync()
    finally:
        # Close the bar even if a run raised.
        if pbar is not None:
            pbar.close()

    return rets


def _run_program_batch_generator(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    """Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor."""
    if num_threads == 1:
        iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
        for arguments in iterator:
            yield run_program(
                program,
                backend,
                (),
                arguments,
                default_sampling_para,
                False,
                True,
            )
    else:
        pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

        # Process in chunks to avoid overwhelming ThreadPoolExecutor
        # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks
        # so we will never reach "yield" until all tasks are done
        chunk_size = 200

        with ThreadPoolExecutor(num_threads) as executor:
            for chunk_start in range(0, len(batch_arguments), chunk_size):
                chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
                chunk_futures = []

                # Submit chunk of tasks
                for i in range(chunk_start, chunk_end):
                    future = executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        batch_arguments[i],
                        default_sampling_para,
                        False,
                        True,
                    )
                    if pbar:
                        future.add_done_callback(lambda _: pbar.update())
                    chunk_futures.append(future)

                # Yield results from this chunk as they complete
                for future in chunk_futures:
                    yield future.result()

        if pbar:
            pbar.close()


def cache_program(program, backend):
    """Trace ``program`` and ask ``backend`` to cache the common prefix it finds.

    Nothing is cached when the traced prefix is empty or its ``len()`` is 64
    or less.
    """
    from sglang.lang.tracer import extract_prefix_by_tracing

    common_prefix = extract_prefix_by_tracing(program, backend)
    if not common_prefix or len(common_prefix) <= 64:
        return
    backend.cache_prefix(common_prefix)


class StreamExecutor:
    """Execute SGL expressions in submission order, optionally on a worker thread.

    Expressions passed to ``submit`` run either inline (``use_thread=False``)
    or on a dedicated worker thread that drains a FIFO queue
    (``use_thread=True``). Execution does three things:

    * appends to ``text_`` (the completion view);
    * appends to ``messages_`` (the chat view);
    * stores named results in ``variables`` and ``meta_info``.

    Readers wait on per-variable ``threading.Event`` objects until the
    producing expression has run.
    """

    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        num_api_spec_tokens=None,
        use_thread=True,
    ):
        """Create an executor and, when ``use_thread`` is True, start its worker.

        Args:
            backend: The backend that performs generation and selection.
            arguments: The program's keyword arguments.
            default_sampling_para: Base sampling parameters. Per-call values
                override them (see ``_resolve_sampling_params``).
            chat_template: Chat template to use. When falsy, falls back to
                ``backend.get_chat_template()``.
            stream: Whether to publish incremental-output events.
            num_api_spec_tokens: Enables API speculative execution when not None.
            use_thread: Run expressions on a worker thread instead of inline.
        """
        from sglang.lang.backend.base_backend import BaseBackend

        self.sid = uuid.uuid4().hex
        self.backend: BaseBackend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        self.variables = {}  # Dict[name: str -> value]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> meta info returned by the backend]
        self.is_finished = False
        self.error_ = None

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
        self.chat_template = chat_template or self.backend.get_chat_template()
        self.cur_role = None
        self.cur_role_begin_pos = None

        # For vision
        self.images_ = []  # Every (path, base64) pair added so far
        self.cur_images = []  # Pairs added since the last role end

        # For fork/join
        self.fork_start_text_pos = None

        # For speculative execution
        self.num_api_spec_tokens = num_api_spec_tokens
        self.speculated_text = ""

        # Worker thread
        self.use_thread = use_thread
        if self.use_thread:
            self.queue = queue.Queue()

            def _run_worker_in_context():
                self._thread_worker_func()

            # Run the worker inside a copy of the creating thread's
            # contextvars context, so context-local values are visible to it.
            self.worker = threading.Thread(
                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
            )
            self.worker.start()

        # For streaming
        if stream:
            self.stream_text_event = threading.Event()
            self.stream_var_event = {}
        else:
            self.stream_text_event = None
            self.stream_var_event = None

    def submit(self, expr: SglExpr):
        """Schedule ``expr`` to run after every previously submitted expression.

        Events for the variables that ``expr`` defines are created here, before
        the expression is queued, so readers can start waiting on them at once.
        """
        # NOTE(review): once the worker has exited after an error, nothing
        # consumes the queue. Expressions submitted after that point are never
        # marked done, so a later sync() or get_var() on them would block.
        self._init_var_event(expr)

        if self.use_thread:
            self.queue.put(expr)
        else:
            self._execute(expr)

    def sync(self):
        """Block until every submitted expression has been processed.

        This is a no-op in inline mode.
        """
        if self.use_thread:
            self.queue.join()

    def get_var(self, name):
        """Return variable ``name``, first waiting for its producer if it has one."""
        if name in self.variable_event:
            self.variable_event[name].wait()
        return self.variables[name]

    def set_var(self, name, value):
        """Set variable ``name`` directly, without touching any event."""
        self.variables[name] = value

    def get_meta_info(self, name, timeout=None):
        """Return the backend meta info recorded for ``name``, or None.

        Raises:
            TimeoutError: If ``name`` has a pending producer that does not
                finish within ``timeout`` seconds.
        """
        if name in self.variable_event:
            got = self.variable_event[name].wait(timeout)
            if not got:
                raise TimeoutError(f"Timeout while waiting for event '{name}'")
        ret = self.meta_info.get(name, None)
        return ret

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Create ``size`` child executors that continue from the current state.

        Each child starts with copies of this executor's variables, text,
        messages, images and current role. Each child also records where the
        fork happened in ``fork_start_text_pos``. ``position_ids_offset`` is
        accepted but not used here.
        """
        if size > 1 and str(self.text_):
            self.submit(SglCommitLazy())

        self.sync()
        size = int(size)

        # NOTE: children get the default num_api_spec_tokens (None) and
        # use_thread (True); they do not inherit this executor's values.
        exes = [
            StreamExecutor(
                self.backend,
                self.arguments,
                self.default_sampling_para,
                self.chat_template,
                self.stream,
            )
            for _ in range(size)
        ]
        for exe in exes:
            exe.variables = dict(self.variables)
            exe.text_ = str(self.text_)
            exe.messages_ = list(self.messages_)
            exe.cur_role = self.cur_role
            exe.cur_role_begin_pos = self.cur_role_begin_pos
            exe.fork_start_text_pos = len(self.text_)
            exe.images_ = list(self.images_)

            # TODO(ying): handle API speculative execution

        return exes

    def text(self):
        """Return the full text after waiting for pending expressions."""
        self.sync()
        return self.text_

    def messages(self):
        """Return the chat messages after waiting for pending expressions."""
        self.sync()
        return self.messages_

    def error(self):
        """Return the exception that stopped the worker, or None."""
        self.sync()
        return self.error_

    def end(self):
        """Ask the worker to exit at the end of its queue, then notify the backend.

        This does not wait for the worker to finish.
        """
        if self.use_thread:
            if self.worker.is_alive():
                # None is the shutdown sentinel understood by _thread_worker_func.
                self.queue.put(None)
        self.backend.end_program(self)

    def _thread_worker_func(self):
        """Worker loop: execute queued expressions until the sentinel or an error."""
        error = None

        while True:
            expr = self.queue.get()
            if expr is None:
                self.queue.task_done()
                break

            try:
                self._execute(expr)
            except Exception as e:
                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                error = e
                break
            self.queue.task_done()
            if self.stream_text_event:
                self.stream_text_event.set()

        # Clean the queue and events
        if error is not None:
            # The failing expression was never task_done()'d, so the first
            # task_done() accounts for it. Each later iteration accounts for one
            # item taken by get_nowait(), until the queue is empty. This keeps
            # queue.join() (i.e. sync()) from blocking on work that will never run.
            try:
                while True:
                    self.queue.task_done()
                    self.queue.get_nowait()
            except queue.Empty:
                pass
            # Wake every waiter; variables that were never produced stay unset.
            for name in self.variable_event:
                self.variable_event[name].set()
            if self.stream_var_event:
                for name in self.stream_var_event:
                    self.stream_var_event[name].set()
            self.error_ = error

        if self.stream_text_event:
            self.stream_text_event.set()

        self.is_finished = True

    def _execute(self, other):
        """Dispatch one expression (a plain ``str`` is treated as constant text)."""
        if isinstance(other, str):
            other = SglConstantText(other)

        assert isinstance(other, SglExpr), f"{other}"

        if isinstance(other, SglConstantText):
            self._execute_fill(other.value)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglImage):
            self._execute_image(other)
        elif isinstance(other, SglVideo):
            self._execute_video(other)
        elif isinstance(other, SglVariable):
            self._execute_variable(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        elif isinstance(other, SglCommitLazy):
            self._execute_commit_lazy_operations(other)
        elif isinstance(other, SglConcateAndAppend):
            if (
                global_config.enable_parallel_encoding
                and self.backend.support_concate_and_append
            ):
                self._execute_concatenate_and_append_kv_cache(other)
            else:
                self._execute_concatenate_and_append_text(other)
        else:
            raise ValueError(f"Unknown type: {type(other)}")

    def _execute_fill(self, value: str, prefix=False):
        """Append ``value`` to the text.

        In chat-model speculative mode, assistant text (other than a role
        prefix) is handed to ``backend.spec_fill`` instead.
        """
        value = str(value)

        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
            and not prefix
        ):
            self.backend.spec_fill(value)
            return

        # Keep the speculated continuation aligned: consume it if the filled
        # text matches its start, otherwise discard it.
        if self.speculated_text.startswith(value):
            self.speculated_text = self.speculated_text[len(value) :]
        else:
            self.speculated_text = ""

        self.text_ += value

    def _execute_image(self, expr: SglImage):
        """Encode an image, remember it, and append the template's image token."""
        path = expr.path

        base64_data = encode_image_base64(path)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _execute_video(self, expr: SglVideo):
        """Encode ``num_frames`` of a video and append the template's image token."""
        path = expr.path
        num_frames = expr.num_frames

        base64_data = encode_video_base64(path, num_frames)

        self.images_.append((path, base64_data))
        self.cur_images.append((path, base64_data))
        self.text_ += self.chat_template.image_token

    def _spec_gen(self, sampling_params):
        """Serve a completion-model ``gen`` from speculatively generated text.

        Text is consumed from ``self.speculated_text``. When there is not
        enough, it is regenerated with the backend, requesting at least
        ``num_api_spec_tokens`` and no stop strings.

        Returns:
            ``(completion, meta_info)``. ``meta_info`` is only populated when a
            regeneration happened during this call.
        """
        stop = sampling_params.stop
        # Snapshot the caller's limit: regen() overwrites
        # sampling_params.max_new_tokens with the larger speculative budget.
        max_new_tokens = sampling_params.max_new_tokens
        meta_info = {}

        def regen():
            nonlocal meta_info

            # Over-generate, ignoring stop strings, so that later gen calls
            # can be served from the surplus left in speculated_text.
            sampling_params.max_new_tokens = max(
                sampling_params.max_new_tokens, self.num_api_spec_tokens
            )
            sampling_params.stop = None
            self.speculated_text, meta_info = self.backend.generate(
                self, sampling_params=sampling_params
            )

        def find_stop():
            # Earliest position of any stop string, or -1 if none occurs.
            if isinstance(stop, str):
                return self.speculated_text.find(stop)
            elif isinstance(stop, (tuple, list)):
                pos = -1
                for stop_str in stop:
                    stop_pos = self.speculated_text.find(stop_str)
                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                        pos = stop_pos
                return pos
            else:
                raise Exception("Wrong type of stop in sampling parameters.")

        # NOTE(review): max_new_tokens is applied as a character count when
        # slicing speculated_text below.
        if stop is None:
            if len(self.speculated_text) < max_new_tokens:
                regen()
            comp = self.speculated_text[:max_new_tokens]
            self.speculated_text = self.speculated_text[max_new_tokens:]
        elif isinstance(stop, (str, list, tuple)):
            if self.speculated_text == "":
                regen()
            stop_pos = find_stop()
            if stop_pos == -1:
                # No stop string found: cap at the caller's requested limit,
                # not the value that regen() may have inflated.
                stop_pos = min(max_new_tokens, len(self.speculated_text))
            comp = self.speculated_text[:stop_pos]
            self.speculated_text = self.speculated_text[stop_pos:]
        else:
            raise ValueError("Wrong type of stop in sampling parameters.")

        return comp, meta_info

    def _execute_gen(self, expr: SglGen):
        """Generate text for ``expr`` and store it in variable ``expr.name``."""
        sampling_params = self._resolve_sampling_params(expr.sampling_params)
        name = expr.name
        if not self.stream:
            if self.num_api_spec_tokens is None:
                comp, meta_info = self.backend.generate(
                    self,
                    sampling_params=sampling_params,
                )

            else:
                if self.backend.is_chat_model:
                    # Speculative execution on models with only chat interface.
                    # Store the calls into a temporary list.
                    # They will be lazily executed later.
                    # NOTE(review): the variable event is not set here; presumably
                    # the backend sets it later (e.g. in role_end_generate) --
                    # TODO confirm.
                    comp, meta_info = self.backend.generate(
                        self,
                        sampling_params=sampling_params,
                        spec_var_name=name,
                    )
                    return

                else:  # Speculative execution on models with completion interface
                    comp, meta_info = self._spec_gen(sampling_params)
            # A list result contributes only its first element to the text,
            # but the whole list is stored as the variable.
            if isinstance(comp, list):
                self.text_ += comp[0]
            else:
                assert isinstance(comp, str)
                self.text_ += comp

            self.variables[name] = comp
            self.meta_info[name] = meta_info
            self.variable_event[name].set()
        else:
            assert (
                self.num_api_spec_tokens is None
            ), "stream is not supported with api speculative execution"
            generator = self.backend.generate_stream(
                self, sampling_params=sampling_params
            )

            self.variables[name] = ""
            self.stream_var_event[name].set()

            for comp, meta_info in generator:
                self.text_ += comp
                self.variables[name] += comp
                self.meta_info[name] = meta_info
                self.stream_var_event[name].set()
                self.stream_text_event.set()

            self.variable_event[name].set()
            self.stream_var_event[name].set()

    def _execute_select(self, expr: SglSelect):
        """Let the backend pick one of ``expr.choices`` and append the decision."""
        choices_decision = self.backend.select(
            self, expr.choices, expr.temperature, expr.choices_method
        )
        if expr.name is not None:
            name = expr.name
            self.variables[name] = choices_decision.decision
            self.meta_info[name] = choices_decision.meta_info
            self.variable_event[name].set()
            if self.stream_var_event:
                self.stream_var_event[name].set()
        self.text_ += choices_decision.decision

    def _execute_variable(self, expr: SglVariable):
        """Append another executor's variable, waiting for it if necessary."""
        src_executor = expr.source_stream_executor
        value = src_executor.get_var(expr.name)
        self._execute_fill(value)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat turn and append the template's prefix for the role."""
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert the default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix, prefix=True)
        self.cur_role_begin_pos = len(self.text_)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close the current chat turn and record it in ``messages_``."""
        if (
            self.cur_role == "assistant"
            and self.num_api_spec_tokens is not None
            and self.backend.is_chat_model
        ):
            # Execute the stored lazy generation calls
            self.backend.role_end_generate(self)
        self.cur_role = None

        # Message content is the text produced since the role prefix.
        new_text = self.text_[self.cur_role_begin_pos :].lstrip()

        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
        self._execute_fill(suffix)

        if self.cur_images:
            # OpenAI vision API format
            last_msg = {
                "role": expr.role,
                "content": [{"type": "text", "text": new_text}],
            }
            for image_path, image_base64_data in self.cur_images:
                last_msg["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64_data}"
                        },
                    }
                )
            self.messages_.append(last_msg)
            self.cur_images = []
        else:
            # OpenAI chat API format
            self.messages_.append({"role": expr.role, "content": new_text})

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        # Temporarily store the start offset; _execute_var_scope_end replaces it
        # with the captured text.
        self.variables[expr.name] = int(len(self.text_))

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        self.variables[expr.name] = self.text_[self.variables[expr.name] :]
        self.variable_event[expr.name].set()

    def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
        self.backend.commit_lazy_operations(self)

    def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
        """Append the text each forked state produced after its fork point."""
        new_text = ""
        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            new_text += exe.text_[exe.fork_start_text_pos :]

        self._execute_fill(new_text)

    def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
        """Append forked states' text and ask the backend to join their KV caches.

        Every source state must have been forked at this executor's current
        text length (asserted below).
        """
        self_len = len(self.text_)

        # Submit to every source before waiting on any, so that (with worker
        # threads) their lazy commits can run concurrently.
        for s in expr.states:
            s.stream_executor.submit(SglCommitLazy())

        for s in expr.states:
            exe = s.stream_executor
            exe.sync()
            assert exe.fork_start_text_pos == self_len
            self.text_ += exe.text_[exe.fork_start_text_pos :]

        src_rids = [state.stream_executor.sid for state in expr.states]
        self.backend.concatenate_and_append(src_rids, self.sid)

    def _init_var_event(self, expr):
        """Create completion events for every variable that ``expr`` defines."""
        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
            self.variable_event[expr.name] = threading.Event()
            if self.stream:
                self.stream_var_event[expr.name] = threading.Event()
        elif isinstance(expr, SglExprList):
            for e in expr.expr_list:
                self._init_var_event(e)

    def _resolve_sampling_params(self, sampling_params):
        """
        Construct sampling param based on default + override values

        The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
        , and `sampling_params` contains the override values from sgl.gen().

        Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
        It also extends the stop strings with the chat template's `stop_str`,
        producing a list.
        """

        # deepcopy is required because the dict has lists inside
        clone = copy.deepcopy(self.default_sampling_para)

        for item in [
            "max_new_tokens",
            "min_new_tokens",
            "n",
            "stop",
            "stop_token_ids",
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "frequency_penalty",
            "presence_penalty",
            "ignore_eos",
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
            "dtype",
            "regex",
            "json_schema",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                setattr(clone, item, value)

        if self.chat_template.stop_str:
            # Normalize stop to a fresh list so the template's stop strings can
            # be appended whatever form it arrived in: None or empty, a single
            # string, or a list/tuple. The previous in-place `+=` raised
            # TypeError for None, and for a non-empty tuple combined with a
            # list stop_str.
            if not clone.stop:
                clone.stop = []
            elif isinstance(clone.stop, str):
                clone.stop = [clone.stop]
            else:
                clone.stop = list(clone.stop)
            template_stop = self.chat_template.stop_str
            if isinstance(template_stop, str):
                # Avoid extending the list with individual characters.
                template_stop = [template_stop]
            clone.stop += list(template_stop)

        return clone

    def __del__(self):
        # __init__ may have raised part-way (e.g. inside get_chat_template()).
        # In that case there is no worker to stop, and end() would only fail
        # with an AttributeError reported as "Exception ignored in __del__".
        if not hasattr(self, "use_thread") or (
            self.use_thread and not hasattr(self, "worker")
        ):
            return
        self.end()


class ProgramState:
    """The state of an SGL program.

    User-facing wrapper around a stream executor: expressions appended with
    ``+=`` (or through the role helpers) are submitted to the executor, and
    text, messages and variables are read back from it.
    """

    # Seconds to wait between polls while waiting for a streamed variable's
    # event to be registered with the executor.
    _VAR_EVENT_POLL_INTERVAL = 0.001

    def __init__(self, stream_executor: StreamExecutor):
        self.stream_executor = stream_executor

    def _role_common(self, name: str, expr: Optional[SglExpr] = None):
        """Emit content under the chat role ``name``.

        With ``expr``, submit it wrapped in ``SglRoleBegin``/``SglRoleEnd`` and
        return the wrapped expression. Without it, return a context manager that
        submits ``SglRoleBegin`` on entry and ``SglRoleEnd`` when the ``with``
        body completes without raising.
        """
        if expr is not None:
            role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
            self.stream_executor.submit(role_expr)
            return role_expr
        else:

            @contextmanager
            def role_scope():
                self.stream_executor.submit(SglRoleBegin(name))
                yield
                self.stream_executor.submit(SglRoleEnd(name))

            return role_scope()

    def system(self, expr: Optional[SglExpr] = None):
        """Emit ``expr`` (or open a role scope) under the ``system`` role."""
        return self._role_common("system", expr)

    def user(self, expr: Optional[SglExpr] = None):
        """Emit ``expr`` (or open a role scope) under the ``user`` role."""
        return self._role_common("user", expr)

    def assistant(self, expr: Optional[SglExpr] = None):
        """Emit ``expr`` (or open a role scope) under the ``assistant`` role."""
        return self._role_common("assistant", expr)

    @contextmanager
    def var_scope(self, name: str):
        """Bracket the ``with`` body in ``SglVarScopeBegin``/``SglVarScopeEnd`` for ``name``."""
        self.stream_executor.submit(SglVarScopeBegin(name))
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

    def fork(
        self,
        size: int = 1,
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Fork ``size`` child states and return them as a ``ProgramStateGroup``
        whose source state is ``self``."""
        stream_executors = self.stream_executor.fork(size, position_ids_offset)
        states = [ProgramState(x) for x in stream_executors]
        state_group = ProgramStateGroup(states, self)
        return state_group

    @contextmanager
    def copy(self, position_ids_offset: Optional[List[int]] = None):
        """Yield a single forked child state; join it back (default join mode)
        on exit, even if the ``with`` body raises."""
        state_group = self.fork(1, position_ids_offset)
        try:
            yield state_group[0]
        finally:
            state_group.join()

    def text(self):
        """Return the executor's ``text()``."""
        return self.stream_executor.text()

    def messages(self):
        """Return the executor's ``messages()``."""
        return self.stream_executor.messages()

    def sync(self):
        """Delegate to the executor's ``sync()`` and return its result."""
        return self.stream_executor.sync()

    def error(self):
        """Return the executor's ``error()``."""
        return self.stream_executor.error()

    def text_iter(self, var_name: Optional[str] = None):
        """Yield program text incrementally.

        Args:
            var_name: ``None`` to stream the whole program text, otherwise the
                name of the variable whose value should be streamed.

        In streaming mode every yielded string is the text appended since the
        previous yield. If the executor finishes without ever registering
        ``var_name``, a single ``""`` is yielded. Without streaming, the complete
        text (or variable value) is yielded once.
        """
        import time  # used only for the polling sleep below

        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        # Text may have been appended after the read above but
                        # before the finish flag was observed; flush it.
                        out = str(self.stream_executor.text_[prev:])
                        if out:
                            yield out
                        break
            else:
                event = None
                while event is None:
                    # Snapshot the flag before the lookup so a variable that was
                    # registered just before the executor finished is still found
                    # (and streamed) instead of being reported as empty.
                    finished = self.stream_executor.is_finished
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    elif finished:
                        yield ""
                        return
                    else:
                        # Sleep between polls instead of spinning on the CPU.
                        time.sleep(self._VAR_EVENT_POLL_INTERVAL)

                while True:
                    event.wait()
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        # The variable may have grown after the read above but
                        # before completion was observed; flush the remainder.
                        out = str(self.stream_executor.variables[var_name][prev:])
                        if out:
                            yield out
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    async def text_async_iter(
        self, var_name: Optional[str] = None, return_meta_data: bool = False
    ):
        """Async counterpart of ``text_iter``.

        Blocking event waits run in the running loop's default executor. When
        ``return_meta_data`` is true and a variable is being streamed, each
        chunk is yielded as ``(chunk, meta_info[var_name])``; in every other
        case plain strings are yielded.
        """
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
            prev = 0
            if var_name is None:
                event = self.stream_executor.stream_text_event
                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.text_[prev:])
                    prev += len(out)
                    if out:
                        yield out
                    if self.stream_executor.is_finished:
                        # Flush text appended between the read above and the
                        # finish flag being observed.
                        out = str(self.stream_executor.text_[prev:])
                        if out:
                            yield out
                        break
            else:
                event = None
                while event is None:
                    # Snapshot before the lookup (see text_iter).
                    finished = self.stream_executor.is_finished
                    if var_name in self.stream_executor.stream_var_event:
                        event = self.stream_executor.stream_var_event[var_name]
                    elif finished:
                        yield ""
                        return
                    else:
                        # Await between polls: a plain busy loop would never hand
                        # control back and would block the whole event loop.
                        await asyncio.sleep(self._VAR_EVENT_POLL_INTERVAL)

                while True:
                    await loop.run_in_executor(None, event.wait)
                    event.clear()
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
                        if return_meta_data:
                            yield out, self.stream_executor.meta_info[var_name]
                        else:
                            yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        # Flush anything appended after the read above.
                        out = str(self.stream_executor.variables[var_name][prev:])
                        if out:
                            if return_meta_data:
                                yield out, self.stream_executor.meta_info[var_name]
                            else:
                                yield out
                        break
        else:
            if var_name is None:
                yield self.text()
            else:
                yield self.get_var(var_name)

    def get_var(self, name):
        """Return the executor's value for variable ``name``."""
        return self.stream_executor.get_var(name)

    def set_var(self, name, value):
        """Set variable ``name`` on the executor and return its result."""
        return self.stream_executor.set_var(name, value)

    def get_meta_info(self, name):
        """Return the executor's meta info for variable ``name``."""
        return self.stream_executor.get_meta_info(name)

    def __iadd__(self, other):
        """Submit ``other`` to the executor; ``None`` raises ``ValueError``."""
        if other is None:
            raise ValueError("Tried to append None to state.")
        self.stream_executor.submit(other)
        return self

    def __getitem__(self, name):
        return self.get_var(name)

    def __setitem__(self, name, value):
        self.set_var(name, value)

    def __contains__(self, name):
        return name in self.stream_executor.variables

    def __del__(self):
        self.stream_executor.end()

    def __repr__(self) -> str:
        return f"ProgramState({self.text()})"


class ProgramStateGroup:
    """A group of program states forked from a common source state.

    Supports indexing, broadcasting ``+=`` to every member, and joining the
    members' results back into the source state.
    """

    def __init__(
        self, states: List[ProgramState], src_state: Optional[ProgramState] = None
    ):
        self.states = states
        self.src_state = src_state

    def join(self, mode: str = "gather_variable"):
        """Merge the forked states into the source state, then end them.

        Args:
            mode: ``"gather_variable"`` syncs each child and copies every variable
                it created (one the source did not have before the join) into
                the source as a list with one entry per child defining it.
                ``"concate_and_append"`` submits ``SglConcateAndAppend`` of the
                children to the source state and syncs it.

        Raises:
            ValueError: if ``mode`` is unknown or the group has no source state.
        """
        if mode not in ("gather_variable", "concate_and_append"):
            raise ValueError(f"Invalid join mode: {mode}")
        if self.src_state is None:
            raise ValueError("Cannot join a ProgramStateGroup without a source state.")

        if mode == "gather_variable":
            # Copy variables back
            src_vars = self.src_state.stream_executor.variables
            # Snapshot the source's keys before merging, so a variable added by
            # one child is still gathered (appended to) for the following ones.
            src_var_set = set(src_vars.keys())
            for child_state in self.states:
                child_state.stream_executor.sync()
                child_vars = child_state.stream_executor.variables
                new_vars = set(child_vars.keys()) - src_var_set

                for k in new_vars:
                    if k in src_vars:
                        src_vars[k].append(child_vars[k])
                    else:
                        src_vars[k] = [child_vars[k]]
        else:
            # Concatenate and append KV cache
            self.src_state += SglConcateAndAppend(self.states)
            # Need a sync here. Otherwise, `states` can be deleted.
            self.src_state.stream_executor.sync()

        for s in self.states:
            s.stream_executor.end()

    def __getitem__(self, i: int):
        return self.states[i]

    def __setitem__(self, i: int, value):
        # Exists so that `group[i] += expr` works: ProgramState.__iadd__ returns
        # the same state, which Python then assigns back here. Any other value
        # is rejected (explicitly, so the check survives `python -O`).
        if self.states[i] is not value:
            raise ValueError(
                "ProgramStateGroup items cannot be replaced; use `group[i] += expr`."
            )

    def __iadd__(self, other):
        """Append to every state in the group.

        ``other`` may be a callable (called with each state's index to build
        that state's expression), a single ``SglExpr`` (appended to every
        state), or a list/tuple holding exactly one item per state.

        Raises:
            ValueError: for an unsupported ``other`` or a list/tuple whose length
                differs from the number of states.
        """
        if isinstance(other, Callable):
            # lambda function
            for i in range(len(self.states)):
                self.states[i] += other(i)
        elif isinstance(other, SglExpr):
            for i in range(len(self.states)):
                self.states[i] += other
        elif isinstance(other, (list, tuple)):
            if len(other) != len(self.states):
                raise ValueError(
                    f"Expected {len(self.states)} items to append, got {len(other)}"
                )
            for i, item in enumerate(other):
                self.states[i] += item
        else:
            raise ValueError(f"Invalid value: {other}")

        return self