Unverified Commit 5dc55a5f authored by Lianmin Zheng, committed by GitHub

Handle truncation errors (#436)

parent 4231a42f
@@ -369,7 +369,7 @@
     "\n",
     "The evaluation scores appear in the bottom right of the logs (screenshot below). Note, that there is no score for `answer_matches_target_llm_grader` and `percent_target_supported_by_context` as these evals are automatically skipped if the target answer is not provided.\n",
     "\n",
-    "![Fixed Max. Tokens](./images/rag/max-tokens-fixed-rag-trace.png)"
+    "![Fixed Max. Tokens](max-tokens-fixed-rag-trace.png)"
    ],
    "metadata": {
     "collapsed": false
@@ -16,7 +16,7 @@ class GlobalConfig:
         # Optimization configs
         self.eager_fill_image = False
-        self.enable_prefix_sharing = True
+        self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
         self.enable_parallel_decoding = True
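
A minimal usage sketch for the renamed flag (the module path and the idea of disabling it are assumptions for illustration, not part of this commit):

```python
# Sketch only: assumes the GlobalConfig instance is exposed as
# sglang.global_config.global_config, as in the sglang package layout.
from sglang.global_config import global_config

# Skip tracing-based prefix pre-caching for batched runs, e.g. when the
# shared prefix depends on per-call arguments and tracing will not find it.
global_config.enable_precache_with_tracing = False
```
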
@@ -86,9 +86,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint

-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)

     # Run all programs
     if num_threads == "auto":
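
A hedged sketch of the batch path that reaches this pre-cache; the endpoint URL, prompt, and questions are illustrative only:

```python
# Sketch only: a batched run in which run_program_batch pre-caches the
# common prefix extracted by tracing. Assumes a local SRT server.
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

@sgl.function
def qa(s, question):
    # The instruction below is the shared prefix; cache_program only caches
    # prefixes longer than 64 characters.
    s += "You are a careful assistant. Answer every question in one short, factual sentence.\n\n"
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32)

sgl.set_default_backend(RuntimeEndpoint("http://localhost:30000"))

# With enable_precache_with_tracing on and more than one argument set,
# the traced prefix is cached once before the programs are launched.
states = qa.run_batch([{"question": "What is 2 + 2?"}, {"question": "Name a prime number."}])
print([s["answer"] for s in states])
```
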
@@ -154,21 +154,12 @@ def run_program_batch(
     return rets


-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing

-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
-
-
-def unpin_program(program, backend):
-    pass
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)


 class StreamExecutor:
@@ -322,7 +313,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                print(f"Error in stream_executor: {get_exception_traceback()}")
+                # print(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
             self.queue.task_done()
@@ -702,9 +693,10 @@ class ProgramState:
         return self.stream_executor.messages()

     def sync(self):
-        ret = self.stream_executor.sync()
-        self.error = self.stream_executor.error
-        return ret
+        return self.stream_executor.sync()
+
+    def error(self):
+        return self.stream_executor.error

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
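
With the error now exposed through a method instead of an attribute set inside `sync()`, callers would check it roughly like this (hedged sketch; `qa` is the illustrative function from the batch example above):

```python
# Sketch only: surface an execution error (e.g. the truncation ValueError
# raised by the tokenizer manager) after a single run finishes.
state = qa.run(question="What is the capital of France?")
state.sync()                    # wait for the stream executor to drain its queue
if state.error() is not None:   # new accessor; sync() no longer sets an attribute
    print("program failed:", state.error())
else:
    print(state["answer"])
```
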
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)

-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program

         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)

     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
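
The old `pin`/`unpin` pair collapses into a single call; a hedged usage sketch (backend construction mirrors the batch example above):

```python
# Sketch only: pre-cache a function's traced prefix on a backend.
qa.cache()  # uses global_config.default_backend
# or, with an explicit backend:
qa.cache(backend=RuntimeEndpoint("http://localhost:30000"))
```
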
@@ -20,6 +20,16 @@ class FinishReason(IntEnum):
     LENGTH = auto()
     STOP_STR = auto()

+    def to_str(self):
+        if self == FinishReason.EOS_TOKEN:
+            return None
+        elif self == FinishReason.LENGTH:
+            return "length"
+        elif self == FinishReason.STOP_STR:
+            return "stop"
+        else:
+            raise ValueError(f"Invalid finish reason: {self}")
+

 class Req:
     def __init__(self, rid, input_text, input_ids):
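
The strings match the usual "length"/"stop" finish reasons, with a natural EOS mapped to `None`; a hedged sanity check (the import path is an assumption, adjust to wherever `FinishReason` lives in your checkout):

```python
# Sketch only: the to_str() mapping used for meta_info["finish_reason"].
from sglang.srt.managers.router.infer_batch import FinishReason  # assumed path

assert FinishReason.LENGTH.to_str() == "length"   # hit max_new_tokens
assert FinishReason.STOP_STR.to_str() == "stop"   # matched a stop string
assert FinishReason.EOS_TOKEN.to_str() is None    # natural end-of-sequence token
```
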
@@ -612,7 +612,7 @@ class ModelRpcServer:
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                    "finish_reason": req.finish_reason.to_str(),
                     "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
@@ -98,7 +98,6 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path, trust_remote_code=server_args.trust_remote_code
         )

         self.context_len = get_context_length(self.hf_config)

         if is_multimodal_model(self.model_path):
@@ -156,6 +155,12 @@ class TokenizerManager:
         else:
             input_ids = obj.input_ids

+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)"
+            )
+
         sampling_params = SamplingParams(**obj.sampling_params)
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
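
Note the `>=`: a prompt exactly as long as the context window is also rejected, since there would be no room left to generate. A standalone sketch of the guard with made-up numbers:

```python
# Sketch only: the guard in isolation. In the real code, context_len comes
# from get_context_length(self.hf_config) and input_ids from the tokenizer.
context_len = 2048
input_ids = [0] * 2048  # a prompt exactly at the limit is rejected too

if len(input_ids) >= context_len:
    raise ValueError(
        f"The input ({len(input_ids)} tokens) is longer than the "
        f"model's context length ({context_len} tokens)"
    )
```
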
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)


 @app.post("/v1/completions")
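
From a client's point of view, an over-long prompt is now rejected up front with an HTTP 400 and a JSON error body; a hedged sketch (URL and payload are illustrative):

```python
# Sketch only: observing the new error path on the /generate endpoint.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "word " * 100_000,  # deliberately longer than the context window
        "sampling_params": {"max_new_tokens": 16},
    },
)
if resp.status_code == 400:
    print("rejected:", resp.json()["error"])
else:
    out = resp.json()
    # meta_info["finish_reason"] is now "length", "stop", or None (see to_str above)
    print(out["text"], out["meta_info"]["finish_reason"])
```
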
@@ -29,7 +29,7 @@ class TestBind(unittest.TestCase):
         tracer = few_shot_qa_2.trace()
         print(tracer.last_node.print_graph_dfs() + "\n")

-    def test_pin(self):
+    def test_cache(self):
         @sgl.function
         def few_shot_qa(s, prompt, question):
             s += prompt
@@ -41,8 +41,7 @@ class TestBind(unittest.TestCase):
         few_shot_qa_2 = few_shot_qa.bind(
             prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
         )
-        few_shot_qa_2.pin(self.backend)
-        few_shot_qa_2.unpin(self.backend)
+        few_shot_qa_2.cache(self.backend)


 if __name__ == "__main__":
@@ -50,4 +49,4 @@
     # t = TestBind()
     # t.setUp()
-    # t.test_pin()
+    # t.test_cache()