ir.py 14.3 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
2
3
4
"""The intermediate representation."""

import dataclasses
import inspect
5
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
6
7
8
9
from typing import List, Optional, Union

from sglang.global_config import global_config

10
11
12
13
14
# Regular-expression fragments describing the textual form of primitive values.
# Optionally signed run of decimal digits.
REGEX_INT = r"[-+]?[0-9]+"
# Optionally signed decimal number; the integer part and the dot may be omitted.
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
# Exactly one of the words True or False.
REGEX_BOOL = r"(True|False)"
# Double-quoted string restricted to word characters, digits and whitespace.
# The more general pattern r"\".*\"" is avoided because it triggers bugs in
# the interegular package.
REGEX_STRING = r"\"[\w\d\s]*\""

Lianmin Zheng's avatar
Lianmin Zheng committed
15
16

@dataclasses.dataclass
class SglSamplingParams:
    """Sampling options for a generation call, convertible to backend kwargs.

    Each ``to_*_kwargs`` method returns a fresh dict using the parameter names
    of one backend. Options that a backend conversion does not emit are left
    out of its dict.
    """

    max_new_tokens: int = 16
    stop: Union[str, List[str]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False

    # For constrained generation. ``dtype`` is never emitted by the
    # to_*_kwargs methods; ``regex`` is emitted only by to_srt_kwargs.
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a new instance carrying the value of every field.

        All fields are copied, including ``ignore_eos``, ``dtype`` and
        ``regex``. Field values are shared with the original, not deep-copied.
        """
        return dataclasses.replace(self)

    def to_openai_kwargs(self):
        """Return kwargs in OpenAI naming; warns if ``regex`` is set."""
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            # An empty stop (the default ``()``) is sent as None.
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Return kwargs in VertexAI naming; warns if ``regex`` is set."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            # A non-positive top_k (e.g. the -1 "disabled" default) becomes None.
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Return kwargs in Anthropic naming; warns if ``regex`` is set."""
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            # A single stop value is wrapped into a one-element list.
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Return kwargs in LiteLLM naming; warns if ``regex`` is set."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Return kwargs in SRT naming, including ``ignore_eos`` and ``regex``."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
        }


class SglFunction:
    """Wrapper around a user function whose first parameter is named ``s``.

    The wrapper records the remaining parameter names and the defaults
    reported by ``inspect.getfullargspec``, and exposes entry points that hand
    the function to the interpreter, tracer or compiler.
    """

    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        """Store ``func`` and inspect its signature.

        Args:
            func: The wrapped callable; its first positional parameter must be
                named ``s``.
            num_api_spec_tokens: Stored as-is on the instance.
            bind_arguments: Optional dict of pre-bound argument values; a
                falsy value is replaced by a new empty dict.

        Raises:
            AssertionError: If ``func`` has no positional parameter or the
                first one is not named ``s``.
        """
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None

        # Parse arguments
        argspec = inspect.getfullargspec(func)
        # Checking for an empty list first turns a zero-argument function into
        # the intended assertion failure instead of an IndexError.
        assert argspec.args and argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []

    def bind(self, **kwargs):
        """Return a new SglFunction with ``kwargs`` merged into the bound arguments.

        Every key must be one of ``arg_names``. Values given here override
        earlier bindings of the same name. ``num_api_spec_tokens`` is carried
        over to the new function; the original is left unmodified.
        """
        assert all(key in self.arg_names for key in kwargs)

        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(
            self.func,
            num_api_spec_tokens=self.num_api_spec_tokens,
            bind_arguments=new_bind_dict,
        )

    def run(
        self,
        *args,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        stream: bool = False,
        backend=None,
        **kwargs,
    ):
        """Run the program once through ``run_program``.

        The sampling keyword options are bundled into an ``SglSamplingParams``
        that is passed as the default sampling parameters. ``args`` and
        ``kwargs`` are forwarded as the program arguments. When ``backend`` is
        falsy, ``global_config.default_backend`` is used.
        """
        from sglang.lang.interpreter import run_program

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        return run_program(self, backend, args, kwargs, default_sampling_para, stream)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
    ):
        """Run the program once per entry of ``batch_kwargs`` via ``run_program_batch``.

        Args:
            batch_kwargs: A list or tuple. If its first entry is a dict, the
                entries are forwarded unchanged. Otherwise every entry must be
                a list or tuple of positional values, which is converted to a
                dict keyed by ``arg_names``.

        Returns:
            ``[]`` for an empty batch, otherwise the result of
            ``run_program_batch``.

        Raises:
            AssertionError: If ``batch_kwargs`` is not a list or tuple.
            ValueError: If a positional entry is not a list/tuple or its length
                falls outside ``len(arg_names) - len(arg_defaults)`` to
                ``len(arg_names)``.
        """
        from sglang.lang.interpreter import run_program_batch

        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []

        if not isinstance(batch_kwargs[0], dict):
            # Change each list of argument values to a dict of
            # arg_name -> arg_value, rejecting the whole batch on the first
            # entry that does not fit the signature.
            max_args = len(self.arg_names)
            min_args = max_args - len(self.arg_defaults)
            converted = []
            for arg_values in batch_kwargs:
                if not (
                    isinstance(arg_values, (list, tuple))
                    and min_args <= len(arg_values) <= max_args
                ):
                    raise ValueError(
                        "Given arguments mismatch the SGL function signature"
                    )
                converted.append(dict(zip(self.arg_names, arg_values)))
            batch_kwargs = converted

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
        )

    def trace(self, *, backend=None, **kwargs):
        """Trace the program with ``trace_program``; accepts keyword arguments only."""
        from sglang.lang.tracer import trace_program

        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)

    def cache(self, backend=None):
        """Call ``cache_program`` for this function on the chosen backend."""
        from sglang.lang.interpreter import cache_program

        backend = backend or global_config.default_backend
        return cache_program(self, backend)

    def compile(self, *, backend=None):
        """Call ``compile_func`` for this function.

        Unlike the other entry points, ``backend`` is forwarded unchanged
        (possibly ``None``) rather than replaced by the global default.
        """
        from sglang.lang.compiler import compile_func

        return compile_func(self, backend)

    def __call__(self, *args, **kwargs):
        """Run the program, or trace it when a ``TracingScope`` is active.

        Inside a tracing scope the backend is taken from the scope's tracer
        state and overrides any ``backend`` keyword passed by the caller.
        """
        from sglang.lang.tracer import TracingScope

        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)

        kwargs["backend"] = tracing_scope.tracer_state.backend
        # NOTE(review): trace() accepts keyword arguments only, so positional
        # arguments raise TypeError on this path — confirm this is intended.
        return self.trace(*args, **kwargs)


class SglExpr:
    """Base class for nodes of the intermediate representation.

    Each instance takes its ``node_id`` from the class-level counter
    ``node_ct``, which is then advanced by one.
    """

    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct = SglExpr.node_ct + 1

    def __add__(self, other):
        """Return ``self`` followed by ``other``; a plain string is wrapped first."""
        rhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(rhs, SglExpr)

        return self.concatenate_ir(self, rhs)

    def __radd__(self, other):
        """Return ``other`` followed by ``self``; a plain string is wrapped first."""
        lhs = SglConstantText(other) if isinstance(other, str) else other
        assert isinstance(lhs, SglExpr), f"{lhs}"

        return self.concatenate_ir(lhs, self)

    def concatenate_ir(self, a, b):
        """Join ``a`` and ``b`` into one flat SglExprList.

        An operand that is already an SglExprList contributes its elements;
        any other operand contributes itself as a single element.
        """
        head = a.expr_list if isinstance(a, SglExprList) else [a]
        tail = b.expr_list if isinstance(b, SglExprList) else [b]
        return SglExprList(head + tail)

    def print_graph_dfs(self):
        """Return a text listing of the nodes reachable from this node.

        A node's predecessor, list elements and variable source are listed
        before the node itself; each node is listed at most once.
        """
        lines = []
        seen = set()

        def visit(node):
            if node is None or node in seen:
                return
            seen.add(node)

            # Dependencies first: predecessor, then children or source.
            if node.prev_node is not None:
                visit(node.prev_node)

            if isinstance(node, SglExprList):
                for child in node.expr_list:
                    visit(child)
            elif isinstance(node, SglVariable):
                visit(node.source)

            # Then the node itself. Fork-style nodes already name their
            # predecessor in their own repr.
            if isinstance(node, (SglFork, SglGetForkItem)):
                lines.append(f"%{node.node_id} = {node}\n")
            elif node.prev_node is not None:
                lines.append(
                    f"%{node.node_id} = %{node.prev_node.node_id} + " + str(node) + "\n"
                )
            else:
                lines.append(f"%{node.node_id} = " + str(node) + "\n")

        visit(self)
        return "".join(lines)


class SglExprList(SglExpr):
    """IR node that holds a sequence of expressions in ``expr_list``."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)


class SglArgument(SglExpr):
    """IR node wrapping a named argument and its value.

    ``len()``, indexing, ``int()`` and ``bool()`` are forwarded to the wrapped
    ``value``; formatting the node into a string is rejected.
    """

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # The wrapped value is returned as-is, so int() raises TypeError unless
        # that value is already an int.
        return self.value

    def __bool__(self):
        # The wrapped value is returned as-is, so bool() raises TypeError
        # unless that value is already a bool.
        # NOTE(review): tracing code may rely on this TypeError for arguments
        # without a real value — confirm before coercing with bool().
        return self.value

    def __format__(self, *args):
        # Formatting would bake the value into a plain string, which the error
        # message below states is incompatible with the tracer.
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )


class SglImage(SglExpr):
    """IR node that refers to an image through ``path``."""

    def __init__(self, path):
        # Run the base initializer so the node has node_id / prev_node / pid
        # like every other SglExpr; print_graph_dfs reads prev_node on each
        # node it visits.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"


Yuanhan Zhang's avatar
Yuanhan Zhang committed
360
361
362
363
364
365
366
367
368
class SglVideo(SglExpr):
    """IR node that refers to a video through ``path`` and ``num_frames``."""

    def __init__(self, path, num_frames):
        # Run the base initializer so the node has node_id / prev_node / pid
        # like every other SglExpr; print_graph_dfs reads prev_node on each
        # node it visits.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"


Lianmin Zheng's avatar
Lianmin Zheng committed
369
370
371
372
373
374
375
376
377
378
379
class SglGen(SglExpr):
    """IR node for a generation step, identified by ``name``.

    The sampling options passed to the constructor are bundled into an
    ``SglSamplingParams`` stored as ``sampling_params``.
    """

    def __init__(
        self,
        name,
        max_new_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        dtype,
        regex,
    ):
        super().__init__()
        self.name = name
        # Every option is forwarded unchanged, by keyword.
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            dtype=dtype,
            regex=regex,
        )

    def __repr__(self):
        return f"Gen('{self.name}')"


class SglConstantText(SglExpr):
    """IR node holding a constant ``value``."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def __repr__(self):
        return f"Constant({self.value!r})"


class SglRoleBegin(SglExpr):
    """IR marker node that opens a section for ``role``."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)


class SglRoleEnd(SglExpr):
    """IR marker node that closes a section for ``role``."""

    def __init__(self, role):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)


class SglSelect(SglExpr):
    """IR node for a selection among ``choices``, identified by ``name``."""

    def __init__(self, name, choices, temperature):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature

    def __repr__(self):
        # temperature is stored but not part of the textual form.
        return "Select({}, choices={})".format(self.name, self.choices)


class SglFork(SglExpr):
    """IR node for a fork into ``number`` branches."""

    def __init__(self, number, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        # Reads prev_node.node_id, so prev_node must be set before printing.
        template = "Fork(%{}, number={}, position_ids_offset={})"
        return template.format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )


class SglGetForkItem(SglExpr):
    """IR node that picks the item at ``index`` from its predecessor node."""

    def __init__(self, index):
        super().__init__()
        self.index = index

    def __repr__(self):
        # Reads prev_node.node_id, so prev_node must be set before printing.
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)


class SglVariable(SglExpr):
    """IR node naming a variable and the ``source`` node it comes from."""

    def __init__(self, name, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)


class SglVarScopeBegin(SglExpr):
    """IR marker node that opens the variable scope ``name``."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)


class SglVarScopeEnd(SglExpr):
    """IR marker node that closes the variable scope ``name``."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)


class SglConcateAndAppend(SglExpr):
    """IR node that carries a collection of ``states``."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)


class SglCommitLazy(SglExpr):
    """IR marker node that carries no data of its own."""

    # No __init__ override: the inherited SglExpr.__init__ does all the setup.

    def __repr__(self):
        return "CommitLazy()"