Unverified Commit 8f6274c8 authored by ObjectNotFound's avatar ObjectNotFound Committed by GitHub
Browse files

Add role documentation, add system begin & end tokens (#793)

parent 325a06c2
...@@ -433,6 +433,24 @@ for out in state.text_iter(): ...@@ -433,6 +433,24 @@ for out in state.text_iter():
print(out, end="", flush=True) print(out, end="", flush=True)
``` ```
#### Roles
Use `sgl.system``sgl.user` and `sgl.assistant` to set roles when using Chat models. You can also define more complex role prompts using begin and end tokens.
```python
@sgl.function
def chat_example(s):
s += sgl.system("You are a helpful assistant.")
# Same as: s += s.system("You are a helpful assistant.")
with s.user():
s += "Question: What is the capital of France?"
s += sgl.assistant_begin()
s += "Answer: " + sgl.gen(max_tokens=100, stop="\n")
s += sgl.assistant_end()
```
#### Tips and Implementation Details #### Tips and Implementation Details
- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. - The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability.
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. - The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`.
......
...@@ -14,6 +14,8 @@ from sglang.api import ( ...@@ -14,6 +14,8 @@ from sglang.api import (
select, select,
set_default_backend, set_default_backend,
system, system,
system_begin,
system_end,
user, user,
user_begin, user_begin,
user_end, user_end,
...@@ -60,4 +62,6 @@ __all__ = [ ...@@ -60,4 +62,6 @@ __all__ = [
"user_end", "user_end",
"assistant_begin", "assistant_begin",
"assistant_end", "assistant_end",
"system_begin",
"system_end",
] ]
...@@ -210,6 +210,14 @@ def assistant(expr: Optional[SglExpr] = None): ...@@ -210,6 +210,14 @@ def assistant(expr: Optional[SglExpr] = None):
return _role_common("assistant", expr) return _role_common("assistant", expr)
def system_begin():
return SglRoleBegin("system")
def system_end():
return SglRoleEnd("system")
def user_begin(): def user_begin():
return SglRoleBegin("user") return SglRoleBegin("user")
......
...@@ -705,9 +705,9 @@ class ProgramState: ...@@ -705,9 +705,9 @@ class ProgramState:
def _role_common(self, name: str, expr: Optional[SglExpr] = None): def _role_common(self, name: str, expr: Optional[SglExpr] = None):
if expr is not None: if expr is not None:
self.stream_executor.submit( role_expr = SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)]) self.stream_executor.submit(role_expr)
) return role_expr
else: else:
@contextmanager @contextmanager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment