tracer.py 7.98 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""Tracing a program."""

import uuid
from typing import Any, Dict, List, Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
    SglArgument,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglFork,
    SglGen,
    SglGetForkItem,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
)


class StopTracing(Exception):
    """Signal that prefix-only tracing has reached a non-constant expression.

    Raised from inside a traced program when the tracer runs in prefix-only
    mode and meets a fork or an expression type it does not record. The
    prefix extractor catches it and keeps whatever was traced before it.
    """


def extract_prefix_by_tracing(program, backend):
    """Return the constant text that the traced program starts with.

    The program is traced in prefix-only mode. Every parameter gets a
    placeholder ``SglArgument(name, None)``, and ``program.bind_arguments``
    overrides those placeholders. Tracing is best-effort: it ends at the first
    fork or unrecorded expression (``StopTracing``), and ``TypeError`` /
    ``AttributeError`` raised by the program are swallowed too. Whatever was
    recorded before the stop is still used.

    Args:
        program: The SGL function to trace. Its ``arg_names``,
            ``bind_arguments`` and ``func`` attributes are used.
        backend: Backend handed to the tracer (used for its chat template).

    Returns:
        The concatenated values of the leading run of ``SglConstantText``
        nodes, or ``""`` if the trace does not start with constant text.
    """
    # Placeholders for every argument; explicitly bound arguments win.
    arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # StopTracing is the normal end of prefix tracing. TypeError and
        # AttributeError are tolerated as well -- presumably raised when the
        # program treats a placeholder argument as a real value
        # (NOTE(review): confirm). The nodes recorded so far remain valid.
        pass

    # Collect the leading run of constant text and join once, instead of
    # growing a string with repeated concatenation.
    parts = []
    for expr in tracer.flatten_nodes():
        if not isinstance(expr, SglConstantText):
            break
        parts.append(expr.value)
    return "".join(parts)


def trace_program(program, arguments, backend):
    """Trace ``program`` completely and return the tracer holding the result.

    Args:
        program: The SGL function to trace. Its ``arg_names``,
            ``bind_arguments`` and ``func`` attributes are used.
        arguments: Mapping of argument names to values. Parameters missing
            from it receive an ``SglArgument(name, None)`` placeholder. The
            mapping is copied, so the caller's dict is left unmodified;
            ``None`` is treated as an empty mapping.
        backend: Backend providing the chat template. A bare ``BaseBackend()``
            is used when ``None``.

    Returns:
        The ``TracerProgramState`` with the recorded nodes and variables; the
        program's return value is stored in its ``ret_value``.
    """
    if backend is None:
        backend = BaseBackend()

    # Work on a copy: the merges below must not leak into the caller's dict.
    arguments = dict(arguments) if arguments is not None else {}
    for name in program.arg_names:
        if name not in arguments:
            arguments[name] = SglArgument(name, None)
    # Bound arguments take precedence over both given and placeholder values.
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer


class TracerProgramState(ProgramState):
    """Program state that records an SGL program's expressions instead of running them.

    Expressions appended with ``+=`` are dispatched by ``_execute`` and stored
    in ``self.nodes``, each linked to its predecessor through ``prev_node``.
    Generations and selections are bound to ``SglVariable`` placeholders in
    ``self.variables``. The backend is only consulted for its chat template.

    ``ProgramState.__init__`` is not called; this class keeps its own state
    and overrides ``fork``, ``__iadd__``, ``get_var`` and ``__del__``.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        """Initialize an empty trace.

        Args:
            backend: Object providing ``get_chat_template()``. If it has an
                ``endpoint`` attribute, that endpoint is used in its place.
            arguments: Argument name -> value mapping. Stored by reference and
                consulted first by ``get_var``.
            only_trace_prefix: When True, ``fork`` and any expression type that
                ``_execute`` does not handle raise ``StopTracing`` instead of
                being recorded.
        """
        self.pid = uuid.uuid4().hex  # stamped onto every executed expression
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register with the active tracing scope and all enclosing scopes.
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Trace a fork of this state into ``size`` branches.

        A shared ``SglFork`` node is chained after this state's last node.
        Each branch is a new tracer whose ``last_node`` is an ``SglGetForkItem``
        pointing at that fork node. Branches start with copies of this state's
        variables and messages, plus its current role and chat template. The
        fork node itself is not appended to ``self.nodes``.
        ``position_ids_offset`` is accepted but not used here.

        Returns:
            A ``ProgramStateGroup`` of the branch tracers.

        Raises:
            StopTracing: In prefix-only mode, because a fork ends the prefix.
        """
        assert size >= 1

        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(size)
        ]

        for i in range(size):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        """Record ``other`` as the newest node, linked to the previous one."""
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch one expression to its recorder and return ``self``.

        A plain ``str`` is wrapped in ``SglConstantText``. ``SglExprList``
        items are dispatched one by one. Unhandled expression types are
        appended as-is, or raise ``StopTracing`` in prefix-only mode.
        """
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        """Trace ``state += expr`` by dispatching ``expr`` through ``_execute``."""
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        """Record constant text; a plain ``str`` is wrapped first."""
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        """Record a generation and bind a placeholder variable for its output.

        Unnamed generations are called ``gen_<n>``, where ``n`` is the current
        number of variables.
        """
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        """Record a selection and bind a placeholder variable for its choice.

        Unnamed selections are called ``select_<n>``, where ``n`` is the
        current number of variables.
        """
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role by recording the template's prefix for it.

        If no message has been recorded yet and the role is not ``system``,
        the template's default system prompt (when non-empty) is first
        recorded as a complete system turn.

        Raises:
            AssertionError: If a role is already open (nested roles).
        """
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close a chat role.

        Records the template's suffix for the role, appends
        ``{"role": role, "content": ""}`` to ``messages_`` and clears
        ``cur_role``.
        """
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        self.messages_.append({"role": expr.role, "content": ""})

        self.cur_role = None

    def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
        """Handle the start of a variable scope.

        Intentionally records nothing: ``_execute_var_scope_end`` binds the
        scope's variable to whichever node is last when the scope closes.
        """
        # Defined so that _execute's SglVarScopeBegin branch has a handler.
        return None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        """Bind ``expr.name`` to a variable sourced from the last recorded node."""
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        """Look up ``name`` among the arguments, then the traced variables.

        An argument whose value is ``None`` is treated as absent. Traced
        variables are returned as a fresh ``SglVariable`` with the same name
        and source.

        Raises:
            KeyError: If ``name`` is neither a non-None argument nor a variable.
        """
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return the recorded nodes in order, expanding any ``SglExprList`` recursively."""

        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # Deliberate no-op override of ProgramState.__del__: the tracer never
        # set up the base-class state (NOTE(review): confirm what the base
        # finalizer tears down).
        pass


class TracingScope:
    """Context manager marking which tracer is currently being traced.

    Scopes nest. Each scope remembers the scope that was active when it was
    entered and restores it on exit. Tracers created while a scope is active
    register themselves through ``add_child_state`` with that scope and every
    enclosing one.
    """

    # Innermost active scope. Class-level, so a single value is shared
    # process-wide (not per thread).
    cur_scope = None

    def __init__(self, tracer_state: "TracerProgramState"):
        self.tracer_state = tracer_state
        # Provisional link; re-captured in __enter__ so a scope constructed
        # before it is entered still links to the scope active at entry time.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        # Skip re-capture when this scope is already the current one, which
        # would otherwise make it its own parent (an endless parent chain).
        if TracingScope.cur_scope is not self:
            self.last_scope = TracingScope.cur_scope
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the enclosing scope. Returning None lets exceptions propagate.
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the innermost active scope, or ``None`` outside any scope."""
        return TracingScope.cur_scope

    def add_child_state(self, state: "TracerProgramState"):
        """Append ``state`` to ``child_states`` of this scope's tracer and of every enclosing scope's tracer."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope