ir.py 16.9 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
"""The intermediate representation."""

import dataclasses
import inspect
5
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
6
7
8
from typing import List, Optional, Union

from sglang.global_config import global_config
9
from sglang.lang.choices import ChoicesSamplingMethod
Lianmin Zheng's avatar
Lianmin Zheng committed
10

11
12
# Regular-expression building blocks for basic Python literal types.
# The numeric patterns accept an optional sign and tolerate trailing
# spaces/newlines after the digits.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
# Matches exactly the literals ``True`` or ``False``.
REGEX_BOOL = r"(True|False)"
# A double-quoted run of word characters, digits and whitespace.
REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
17

@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters for a generation call.

    The ``to_*_kwargs`` methods translate these fields into the keyword
    dictionaries built for each backend (OpenAI, VertexAI, Anthropic,
    LiteLLM, SRT). Each backend receives only the subset of fields that the
    corresponding method puts into its dictionary.
    """

    max_new_tokens: int = 128
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    # These three defaults must be plain ``None``. A trailing comma here
    # would turn the default into the one-element tuple ``(None,)``, which
    # is truthy and is not ``None``.
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None

    # for constrained generation; only `regex` is emitted, and only by
    # to_srt_kwargs -- the other to_xxx_kwargs methods include neither
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a new ``SglSamplingParams`` with the sampling fields copied.

        ``dtype`` and ``regex`` are not copied; the clone gets their default
        value (``None``). Container fields such as ``stop`` are shared with
        the original, not deep-copied.
        """
        return SglSamplingParams(
            max_new_tokens=self.max_new_tokens,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            ignore_eos=self.ignore_eos,
            return_logprob=self.return_logprob,
            logprob_start_len=self.logprob_start_len,
            top_logprobs_num=self.top_logprobs_num,
            return_text_in_logprobs=self.return_text_in_logprobs,
        )

    def to_openai_kwargs(self):
        """Build the keyword dictionary for the OpenAI backend.

        Emits a warning when ``regex`` is set. An empty ``stop`` is sent as
        ``None``.
        """
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Build the keyword dictionary for the VertexAI backend.

        Emits a warning when ``regex`` is set. A non-positive ``top_k``
        (including the ``-1`` "disabled" value) is sent as ``None``.
        """
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Build the keyword dictionary for the Anthropic backend.

        Emits a warning when ``regex`` is set. A ``stop`` value that is not a
        list or tuple is wrapped in a single-element list.
        """
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Build the keyword dictionary for the LiteLLM backend.

        Emits a warning when ``regex`` is set. An empty ``stop`` is sent as
        ``None``.
        """
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Build the keyword dictionary for the SRT backend.

        This is the only conversion that forwards ``stop_token_ids``,
        ``ignore_eos`` and ``regex``.
        """
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
        }


class SglFunction:
    """Wrapper around a user-defined function whose first parameter is ``s``.

    The wrapper records the function's remaining parameter names and default
    values, and offers entry points to run, batch-run, trace, cache and
    compile the function through the corresponding sglang modules.
    """

    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        """Store ``func`` and inspect its signature.

        Args:
            func: The callable to wrap. Its first positional parameter must
                be named ``s``.
            num_api_spec_tokens: Stored as-is on the instance.
            bind_arguments: Optional dict of pre-bound keyword arguments;
                a falsy value becomes an empty dict.

        Raises:
            AssertionError: If ``func`` has no positional parameters or the
                first one is not named ``s``.
        """
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        # Check for an empty argument list first so that a zero-argument
        # function fails with the intended message instead of an IndexError.
        assert argspec.args and argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []

    def bind(self, **kwargs):
        """Return a new ``SglFunction`` with additional pre-bound arguments.

        Every key must be one of the wrapped function's parameter names
        (excluding ``s``). New bindings override existing ones with the same
        name. ``num_api_spec_tokens`` is carried over to the new wrapper.
        """
        assert all(key in self.arg_names for key in kwargs)

        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(
            self.func,
            num_api_spec_tokens=self.num_api_spec_tokens,
            bind_arguments=new_bind_dict,
        )

    def run(
        self,
        *args,
        max_new_tokens: int = 128,
        # Immutable defaults: these objects are stored on the sampling
        # params and leave this function, so a shared list default could be
        # mutated by downstream code and leak between calls.
        stop: Union[str, List[str]] = (),
        stop_token_ids: Optional[List[int]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        """Run the function once through ``run_program``.

        The sampling-related keyword arguments are packed into an
        ``SglSamplingParams`` that is passed along as the default sampling
        parameters. ``args`` and ``kwargs`` are forwarded as the program's
        arguments. When ``backend`` is falsy, ``global_config.default_backend``
        is used.
        """
        from sglang.lang.interpreter import run_program

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        stop_token_ids: Optional[List[int]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        """Run the function over a batch of argument sets.

        Args:
            batch_kwargs: A list or tuple. Its entries are either dicts of
                keyword arguments or sequences of positional values; the
                type of the first entry decides which form is assumed.
                Positional entries are converted to dicts using the wrapped
                function's parameter names.

        Returns:
            ``[]`` for an empty batch, otherwise the result of
            ``run_program_batch``.

        Raises:
            Exception: If any positional entry is not a list/tuple or its
                length is outside the range allowed by the function's
                required and total parameter counts.
        """
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        if not isinstance(batch_kwargs[0], dict):
            num_programs = len(batch_kwargs)
            # change the list of argument values to dict of arg_name -> arg_value
            batch_kwargs = [
                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                for arg_values in batch_kwargs
                if isinstance(arg_values, (list, tuple))
                and len(self.arg_names) - len(self.arg_defaults)
                <= len(arg_values)
                <= len(self.arg_names)
            ]
            # Ensure to raise an exception if the number of arguments mismatch
            if len(batch_kwargs) != num_programs:
                raise Exception("Given arguments mismatch the SGL function signature")

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the function via ``trace_program`` with keyword arguments only."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def cache(self, backend=None):
        """Call ``cache_program`` for this function on the given or default backend."""
        from sglang.lang.interpreter import cache_program

        backend = backend or global_config.default_backend
        return cache_program(self, backend)

    def compile(self, *, backend=None):
        """Compile the function via ``compile_func``.

        Unlike the other entry points, ``backend`` is passed through
        unchanged (no fallback to the global default is applied here).
        """
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        """Run the function, or trace it when inside a tracing scope.

        Outside a tracing scope this is equivalent to ``run``. Inside one,
        the scope's backend overrides any ``backend`` keyword and the call is
        routed to ``trace``.
        """
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            # NOTE(review): ``trace`` is keyword-only, so any positional
            # ``args`` passed here raise TypeError -- confirm whether
            # positional arguments should be mapped to names first.
            return self.trace(*args, **kwargs)


class SglExpr:
    """Base class for IR nodes; each instance gets a sequential ``node_id``."""

    # Class-wide counter supplying the id of the next node to be created.
    node_ct = 0

    def __init__(self):
        # Claim the current counter value as this node's id, then advance it.
        self.node_id = SglExpr.node_ct
        SglExpr.node_ct += 1
        self.prev_node = None
        self.pid = None

    def __add__(self, other):
        """Return ``self + other`` as an expression list; strings are wrapped."""
        rhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(rhs, SglExpr)
        return self.concatenate_ir(self, rhs)

    def __radd__(self, other):
        """Return ``other + self`` as an expression list; strings are wrapped."""
        lhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(lhs, SglExpr), f"{lhs}"
        return self.concatenate_ir(lhs, self)

    def concatenate_ir(self, a, b):
        """Join ``a`` and ``b`` into one flat ``SglExprList``.

        An operand that is already an ``SglExprList`` contributes its
        elements; any other operand contributes itself.
        """
        left_items = a.expr_list if isinstance(a, SglExprList) else [a]
        right_items = b.expr_list if isinstance(b, SglExprList) else [b]
        return SglExprList(left_items + right_items)

    def print_graph_dfs(self):
        """Return a textual dump of the graph reachable from this node.

        Dependencies (``prev_node``, list elements, a variable's source) are
        emitted before the node that uses them; each node appears once.
        """
        lines = []
        seen = set()

        def visit(node):
            if node is None or node in seen:
                return
            seen.add(node)

            # Emit dependencies first.
            if node.prev_node is not None:
                visit(node.prev_node)

            if isinstance(node, SglExprList):
                for child in node.expr_list:
                    visit(child)
            elif isinstance(node, SglVariable):
                visit(node.source)

            # Emit the node itself. Fork-related nodes already mention their
            # predecessor in their own repr, so it is not repeated for them.
            if isinstance(node, (SglFork, SglGetForkItem)):
                lines.append(f"%{node.node_id} = {node}\n")
            elif node.prev_node is not None:
                lines.append(
                    f"%{node.node_id} = %{node.prev_node.node_id} + " + str(node) + "\n"
                )
            else:
                lines.append(f"%{node.node_id} = " + str(node) + "\n")

        visit(self)
        return "".join(lines)


class SglExprList(SglExpr):
    """IR node holding an ordered list of sub-expressions."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        # The list is stored by reference, not copied.
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)


class SglArgument(SglExpr):
    """IR node wrapping a named argument value.

    ``len()``, indexing, ``int()`` and ``bool()`` are delegated to the
    wrapped value. Formatting (for example inside an f-string) is rejected
    with ``TypeError``.
    """

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # ``__int__`` must return an int; returning the raw value makes
        # ``int(arg)`` raise TypeError for any non-int value.
        return int(self.value)

    def __bool__(self):
        # ``__bool__`` must return a bool; returning the raw value makes
        # truth-testing raise TypeError for any non-bool value.
        return bool(self.value)

    def __format__(self, *args):
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )


class SglImage(SglExpr):
    """IR node referring to an image by path."""

    def __init__(self, path: str):
        # Initialise the base node so this instance has ``node_id``,
        # ``prev_node`` and ``pid`` like every other expression node.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"


Yuanhan Zhang's avatar
Yuanhan Zhang committed
392
class SglVideo(SglExpr):
    """IR node referring to a video by path, with a frame count."""

    def __init__(self, path: str, num_frames: int):
        # Initialise the base node so this instance has ``node_id``,
        # ``prev_node`` and ``pid`` like every other expression node.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"


Lianmin Zheng's avatar
Lianmin Zheng committed
401
402
403
class SglGen(SglExpr):
    """IR node for a model generation call.

    All sampling arguments default to ``None`` and are stored together in
    ``self.sampling_params``.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
    ):
        """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
        super().__init__()
        self.name = name
        # Every value is passed through explicitly, including ``None``, so
        # the dataclass's own defaults are not applied here.
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            dtype=dtype,
            regex=regex,
        )

    def __repr__(self):
        return f"Gen('{self.name}')"


class SglConstantText(SglExpr):
    """IR node holding a constant text value."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return f"Constant({repr(self.value)})"


class SglRoleBegin(SglExpr):
    """IR node marking the beginning of a role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return f"RoleBegin({self.role})"


class SglRoleEnd(SglExpr):
    """IR node marking the end of a role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return f"RoleEnd({self.role})"


class SglSelect(SglExpr):
    """IR node for selecting one item from a list of choices."""

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature
        self.choices_method = choices_method

    def __repr__(self):
        # ``temperature`` is stored on the node but not shown in the repr.
        return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
Lianmin Zheng's avatar
Lianmin Zheng committed
490
491
492


class SglFork(SglExpr):
    """IR node for a fork into ``number`` branches."""

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        # ``prev_node`` is None until the node is linked; fall back to
        # ``None`` instead of raising AttributeError from the repr.
        prev_id = None if self.prev_node is None else self.prev_node.node_id
        return (
            f"Fork(%{prev_id}, number={self.number}, "
            f"position_ids_offset={self.position_ids_offset})"
        )


class SglGetForkItem(SglExpr):
    """IR node that selects the item at ``index`` from a preceding node."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        # ``prev_node`` is None until the node is linked; fall back to
        # ``None`` instead of raising AttributeError from the repr.
        prev_id = None if self.prev_node is None else self.prev_node.node_id
        return f"GetForkItem(%{prev_id}, index={self.index})"


class SglVariable(SglExpr):
    """IR node for a named variable that refers to a source node."""

    def __init__(self, name: str, source):
        super().__init__()
        self.name = name
        # The repr reads ``source.node_id``, so ``source`` is expected to be
        # another expression node.
        self.source = source

    def __repr__(self):
        return f"Variable('{self.name}', source=%{self.source.node_id})"


class SglVarScopeBegin(SglExpr):
    """IR node marking the beginning of a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return f"VarScopeBegin('{self.name}')"


class SglVarScopeEnd(SglExpr):
    """IR node marking the end of a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return f"VarScopeEnd('{self.name}')"


class SglConcateAndAppend(SglExpr):
    """IR node carrying a collection of states to concatenate and append."""

    def __init__(self, states):
        super().__init__()
        # Stored by reference; no copy is made.
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)


class SglCommitLazy(SglExpr):
    """IR node with no payload of its own; its repr is ``CommitLazy()``."""

    def __init__(self):
        super().__init__()

    def __repr__(self):
        return "CommitLazy()"